提交 8ae4b36f 编写于 作者: R root

fix the data don't have the session issue

上级 a7c0e021
......@@ -13,12 +13,14 @@
# limitations under the License.
# ============================================================================
"""
Parser for AI CPU preprocess data.
The parser for AI CPU preprocess data.
"""
from mindinsight.profiler.common._utils import fwrite_format, get_file_join_name
from tabulate import tabulate
from mindinsight.profiler.common._utils import fwrite_format, get_file_join_name
from mindinsight.profiler.common.log import logger
_source_file_target = 'DATA_PREPROCESS.dev.AICPU'
_dst_file_title = 'title:DATA_PREPROCESS AICPU'
_dst_file_column_title = ['serial_number', 'node_name', 'total_time(us)', 'dispatch_time(us)',
......@@ -41,53 +43,59 @@ class DataPreProcessParser:
self._source_file_name = self._get_source_file()
def _get_source_file(self):
"""get log file name, which was created by ada service"""
"""Get log file name, which was created by ada service."""
return get_file_join_name(self._input_path, _source_file_target)
def execute(self):
"""execute the parser, get result data, and write it to the output file"""
ai_cpu_lines = list()
"""Execute the parser, get result data, and write it to the output file."""
if self._source_file_name is None:
logger.info("Did not find the aicpu profiling source file")
return
with open(self._source_file_name, 'rb') as ai_cpu_data:
ai_cpu_str = str(ai_cpu_data.read().replace(b'\n\x00', b' ___ ')
.replace(b'\x00', b' ___ '))[2:-1]
ai_cpu_lines = ai_cpu_str.split(" ___ ")
create_session_cnt = 0
for line in ai_cpu_lines:
if "Create session start" in line:
create_session_cnt += 1
start_idx = create_session_cnt * 2
ai_cpu_lines = ai_cpu_lines[start_idx:-3]
node_list = list()
ai_cpu_total_time_summary = 0
# node serial number
serial_number = 1
for i in range(len(ai_cpu_lines)-1):
node_line = ai_cpu_lines[i]
thread_line = ai_cpu_lines[i+1]
if "Node" in node_line and "Thread" in thread_line:
# get the node data from node_line
node_name = node_line.split(',')[0].split(':')[-1]
run_v2_start = node_line.split(',')[1].split(':')[-1]
compute_start = node_line.split(',')[2].split(':')[-1]
mercy_start = node_line.split(',')[3].split(':')[-1]
mercy_end = node_line.split(',')[4].split(':')[-1]
run_v2_end = node_line.split(',')[5].split(':')[-1]
# get total_time and dispatch_time from thread line
total_time = thread_line.split(',')[-1].split('=')[-1].split()[0]
dispatch_time = thread_line.split(',')[-2].split('=')[-1].split()[0]
node_data = [serial_number, node_name, total_time, dispatch_time, run_v2_start, compute_start,
mercy_start, mercy_end, run_v2_end]
node_list.append(node_data)
# calculate the total time
ai_cpu_total_time_summary += int(total_time)
# increase node serial number
serial_number += 1
elif "Node" in node_line and "Thread" not in thread_line:
node_name = node_line.split(',')[0].split(':')[-1]
logger.warning("The node:%s cannot find thread data", node_name)
node_list.append(["AI CPU Total Time(us):", ai_cpu_total_time_summary])
if start_idx > 0:
serial_number = 1
for line in ai_cpu_lines:
if "Node" in line:
node_name = line.split(',')[0].split(':')[-1]
run_v2_start = line.split(',')[1].split(':')[-1]
compute_start = line.split(',')[2].split(':')[-1]
mercy_start = line.split(',')[3].split(':')[-1]
mercy_end = line.split(',')[4].split(':')[-1]
run_v2_end = line.split(',')[5].split(':')[-1]
node_data = [serial_number, node_name, run_v2_start, compute_start,
mercy_start, mercy_end, run_v2_end]
node_list.append(node_data)
serial_number += 1
elif "Thread" in line:
# total_time and dispatch_time joins node list
total_time = line.split(',')[-1].split('=')[-1].split()[0]
dispatch_time = line.split(',')[-2].split('=')[-1].split()[0]
if node_list:
node_list[-1][2:2] = [total_time, dispatch_time]
ai_cpu_total_time_summary += int(total_time)
node_list.append(["AI CPU Total Time:", ai_cpu_total_time_summary])
if node_list:
fwrite_format(self._output_filename, data_source=_dst_file_title, is_print=True,
is_start=True)
fwrite_format(self._output_filename,
data_source=tabulate(node_list, _dst_file_column_title,
tablefmt='simple'),
is_start=True, is_print=True)
if node_list:
fwrite_format(self._output_filename, data_source=_dst_file_title, is_print=True,
is_start=True)
fwrite_format(self._output_filename,
data_source=tabulate(node_list, _dst_file_column_title,
tablefmt='simple'),
is_start=True, is_print=True)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册