提交 f1b64f74 编写于 作者: W WeibiaoYu

optimize profiling data analysis logic

上级 6f953c0c
......@@ -13,8 +13,8 @@
# limitations under the License.
# ============================================================================
"""The parser for hwts log file."""
import os
import struct
from tabulate import tabulate
from mindinsight.profiler.common._utils import fwrite_format, get_file_join_name
from mindinsight.profiler.common.log import logger
......@@ -30,8 +30,7 @@ class HWTSLogParser:
_source_file_target = 'hwts.log.data.45.dev.profiler_default_tag'
_dst_file_title = 'title:45 HWTS data'
_dst_file_column_title = ['Type', 'cnt', 'Core ID', 'Block ID', 'Task ID',
'Cycle counter', 'Stream ID']
_dst_file_column_title = 'Type cnt Core_ID Block_ID Task_ID Cycle_counter Stream_ID'
def __init__(self, input_path, output_filename):
self._input_path = input_path
......@@ -43,9 +42,11 @@ class HWTSLogParser:
file_name = get_file_join_name(self._input_path, self._source_file_target)
if not file_name:
msg = ("Fail to find hwts log file, under directory %s"
% self._input_path)
raise RuntimeError(msg)
data_path = os.path.join(self._input_path, "data")
file_name = get_file_join_name(data_path, self._source_file_target)
if not file_name:
msg = ("Fail to find hwts log file, under profiling directory")
raise RuntimeError(msg)
return file_name
......@@ -60,7 +61,7 @@ class HWTSLogParser:
content_format = ['QIIIIIIIIIIII', 'QIIQIIIIIIII', 'IIIIQIIIIIIII']
log_type = ['Start of task', 'End of task', 'Start of block', 'End of block', 'Block PMU']
result_data = []
result_data = ""
with open(self._source_flie_name, 'rb') as hwts_data:
while True:
......@@ -73,6 +74,7 @@ class HWTSLogParser:
byte_first_four = struct.unpack('BBHHH', line[0:8])
byte_first = bin(byte_first_four[0]).replace('0b', '').zfill(8)
ms_type = byte_first[-3:]
is_warn_res0_ov = byte_first[4]
cnt = int(byte_first[0:4], 2)
core_id = byte_first_four[1]
blk_id, task_id = byte_first_four[3], byte_first_four[4]
......@@ -80,22 +82,28 @@ class HWTSLogParser:
result = struct.unpack(content_format[0], line[8:])
syscnt = result[0]
stream_id = result[1]
result_data.append((log_type[int(ms_type, 2)], cnt, core_id, blk_id, task_id, syscnt, stream_id))
elif ms_type == '011': # log type 3
result = struct.unpack(content_format[1], line[8:])
syscnt = result[0]
stream_id = result[1]
result_data.append((log_type[int(ms_type, 2)], cnt, core_id, blk_id, task_id, syscnt, stream_id))
elif ms_type == '100': # log type 4
result = struct.unpack(content_format[2], line[8:])
stream_id = result[2]
result_data.append((log_type[int(ms_type, 2)], cnt, core_id, blk_id, task_id, total_cyc, stream_id))
if is_warn_res0_ov == '0':
syscnt = result[4]
else:
syscnt = None
else:
logger.info("Profiling: invalid hwts log record type %s", ms_type)
continue
if int(task_id) < 25000:
task_id = stream_id + "_" + task_id
result_data += ("%-14s %-4s %-8s %-9s %-8s %-15s %s\n" %(log_type[int(ms_type, 2)], cnt, core_id,
blk_id, task_id, syscnt, stream_id))
fwrite_format(self._output_filename, data_source=self._dst_file_title, is_start=True)
fwrite_format(self._output_filename, data_source=tabulate(result_data,
self._dst_file_column_title,
tablefmt='simple'))
fwrite_format(self._output_filename, data_source=self._dst_file_column_title)
fwrite_format(self._output_filename, data_source=result_data)
return True
......@@ -14,8 +14,6 @@
# ============================================================================
"""Op compute time files parser."""
import os
from tabulate import tabulate
from mindinsight.profiler.common._utils import fwrite_format
from mindinsight.profiler.common.exceptions.exceptions import ProfilerFileNotFoundException, \
ProfilerIOException
......@@ -25,7 +23,6 @@ from mindinsight.profiler.parser.container import HWTSContainer
TIMELINE_FILE_COLUMN_TITLE = 'op_name, stream_id, start_time(ms), duration(ms)'
class OPComputeTimeParser:
"""
Join hwts info and framework info, get op time info, and output to the result file.
......@@ -37,7 +34,8 @@ class OPComputeTimeParser:
"""
_dst_file_title = 'title:op compute time'
_dst_file_column_title = ['op_name', 'compute_time(ms)', 'stream_id']
_dst_file_column_title = 'op_name compute_time(ms) stream_id'
_dst_file_column_title += '\n------------ --------------- ---------'
def __init__(self, hwts_output_file, output_filename, op_task_info,
output_path, device_id):
......@@ -100,12 +98,15 @@ class OPComputeTimeParser:
op_name_count_dict, op_name_task_dict, op_name_start_time
)
result_data = []
result_data = ""
total_time = 0
for op_name, time in op_name_time_dict.items():
if op_name in op_name_stream_dict.keys():
stream_id = op_name_stream_dict[op_name]
avg_time = time / op_name_count_dict[op_name]
result_data.append([op_name, avg_time, stream_id])
total_time += avg_time
result_data += ("%s %s %s\n" %(op_name, str(avg_time), stream_id))
result_data += ("total op %s 0" %(str(total_time)))
timeline_data = []
for op_name, time in op_name_time_dict.items():
......@@ -130,23 +131,15 @@ class OPComputeTimeParser:
op name, average time, and stream id.
Args:
result_data (list): The metadata to be written into the file.
[
['op_name_1', 'avg_time_1', 'stream_id_1'],
['op_name_2', 'avg_time_2', 'stream_id_2'],
[...]
]
result_data (str): The metadata to be written into the file.
'op_name_1', 'avg_time_1', 'stream_id_1',
'op_name_2', 'avg_time_2', 'stream_id_2',
...
"""
result_data.sort(key=lambda x: x[0])
total_time = 0
for item in result_data:
total_time += item[1]
result_data.append(["total op", total_time, 0])
fwrite_format(self._output_filename, data_source=self._dst_file_title, is_start=True)
fwrite_format(self._output_filename, data_source=tabulate(result_data,
self._dst_file_column_title,
tablefmt='simple'))
fwrite_format(self._output_filename, data_source=self._dst_file_column_title)
fwrite_format(self._output_filename, data_source=result_data)
def _write_timeline_data_into_file(self, timeline_data):
"""
......
......@@ -89,9 +89,9 @@ class Profiler:
except ValueError as err:
logger.error("Profiling: fail to get context, %s", err)
if not dev_id:
if not dev_id or not dev_id.isdigit():
dev_id = os.getenv('DEVICE_ID')
if not dev_id:
if not dev_id or not dev_id.isdigit():
dev_id = "0"
logger.error("Fail to get DEVICE_ID, use 0 instead.")
......@@ -105,12 +105,12 @@ class Profiler:
self._container_path = os.path.join(self._base_profiling_container_path, dev_id)
data_path = os.path.join(self._container_path, "data")
if not os.path.exists(data_path):
os.makedirs(data_path)
os.makedirs(data_path, exist_ok=True)
self._output_path = validate_and_normalize_path(output_path,
'Profiler output path (' + output_path + ')')
self._output_path = os.path.join(self._output_path, "profiler")
if not os.path.exists(self._output_path):
os.makedirs(self._output_path)
os.makedirs(self._output_path, exist_ok=True)
os.environ['PROFILING_MODE'] = 'true'
os.environ['PROFILING_OPTIONS'] = 'training_trace:task_trace'
......@@ -121,8 +121,8 @@ class Profiler:
context.set_context(enable_profiling=True, profiling_options="training_trace:task_trace")
except ImportError:
logger.error("Profiling: fail to import context from mindspore.")
except ValueError as err:
logger.error("Profiling: fail to set context, %s", err.message)
except ValueError:
logger.error("Profiling: fail to set context enable_profiling")
os.environ['AICPU_PROFILING_MODE'] = 'true'
os.environ['PROFILING_DIR'] = str(self._container_path)
......@@ -162,8 +162,7 @@ class Profiler:
job_id = self._get_profiling_job_id()
if not job_id:
msg = ("Fail to get profiling job, please check whether job dir was generated under path %s" \
% PROFILING_LOG_BASE_PATH)
msg = ("Fail to get profiling job, please check whether job dir was generated")
raise RuntimeError(msg)
logger.info("Profiling: job id is %s ", job_id)
......@@ -296,6 +295,13 @@ class Profiler:
"""Disable the profiling collection service, called after training."""
os.environ['PROFILING_MODE'] = str("false")
try:
import mindspore.context as context
context.set_context(enable_profiling=False)
except ImportError:
logger.error("Profiling: fail to import context from mindspore.")
except ValueError:
logger.error("Profiling: fail to set context enable_profiling")
def _get_profiling_job_id(self):
"""Get profiling job id, which was generated by ada service.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册