未验证 提交 fa4ae88a 编写于 作者: H huangjiyi 提交者: GitHub

update (#53146)

上级 47fa8066
...@@ -21,22 +21,6 @@ import warnings ...@@ -21,22 +21,6 @@ import warnings
import pandas as pd import pandas as pd
def preprocess_macro(file_content, processed_file_path):
if file_content is None:
return file_content
# comment out external macro
file_content = re.sub(r'#(include|pragma)', r'// \g<0>', file_content)
with open(processed_file_path, "w") as f:
f.write(file_content)
# expand macro and correct format
subprocess.run(
['g++', '-E', processed_file_path, '-o', processed_file_path]
)
subprocess.run(['clang-format', '-i', processed_file_path])
file_content = open(processed_file_path, "r").read()
return file_content
def search_pattern(pattern, file_content): def search_pattern(pattern, file_content):
if file_content is not None: if file_content is not None:
match_result = re.search(pattern, file_content) match_result = re.search(pattern, file_content)
...@@ -100,8 +84,9 @@ class KernelSignatureSearcher: ...@@ -100,8 +84,9 @@ class KernelSignatureSearcher:
self.func_signature_map[match_result[1]] = match_result[0] self.func_signature_map[match_result[1]] = match_result[0]
def search_kernel_registration(self, path): def search_kernel_registration(self, path):
self.processed_file_path = osp.join( self.tmp_file_path = osp.join(self.build_path, '.tmp_file.cc')
self.build_path, '.processed_file.cc' self.processed_file_path = self.tmp_file_path.replace(
'.tmp_file.cc', '.processed_file.cc'
) )
for file in os.listdir(path): for file in os.listdir(path):
file_path = osp.join(path, file) file_path = osp.join(path, file)
...@@ -113,6 +98,7 @@ class KernelSignatureSearcher: ...@@ -113,6 +98,7 @@ class KernelSignatureSearcher:
if re.match(r'\w+_kernel\.(cc|cu)', file): if re.match(r'\w+_kernel\.(cc|cu)', file):
self._search_kernel_registration(file_path, file) self._search_kernel_registration(file_path, file)
if osp.exists(self.processed_file_path): if osp.exists(self.processed_file_path):
os.remove(self.tmp_file_path)
os.remove(self.processed_file_path) os.remove(self.processed_file_path)
def _search_kernel_registration(self, file_path, file): def _search_kernel_registration(self, file_path, file):
...@@ -121,9 +107,7 @@ class KernelSignatureSearcher: ...@@ -121,9 +107,7 @@ class KernelSignatureSearcher:
# if some kernel registration is in macro, preprocess macro first # if some kernel registration is in macro, preprocess macro first
self.file_preprocessed = False self.file_preprocessed = False
if re.search(self.macro_kernel_reg_pattern, file_content): if re.search(self.macro_kernel_reg_pattern, file_content):
file_content = preprocess_macro( file_content = self.preprocess_macro(file_content)
file_content, self.processed_file_path
)
self.file_preprocessed = True self.file_preprocessed = True
# search kernel registration # search kernel registration
match_results = re.findall(self.kernel_reg_pattern, file_content) match_results = re.findall(self.kernel_reg_pattern, file_content)
...@@ -161,9 +145,7 @@ class KernelSignatureSearcher: ...@@ -161,9 +145,7 @@ class KernelSignatureSearcher:
return kernel_signature return kernel_signature
# expand macro and search again # expand macro and search again
if not self.file_preprocessed: if not self.file_preprocessed:
file_content = preprocess_macro( file_content = self.preprocess_macro(file_content)
file_content, self.processed_file_path
)
kernel_signature = search_pattern( kernel_signature = search_pattern(
target_kernel_signature_pattern, file_content target_kernel_signature_pattern, file_content
) )
...@@ -175,9 +157,7 @@ class KernelSignatureSearcher: ...@@ -175,9 +157,7 @@ class KernelSignatureSearcher:
if osp.exists(header_path): if osp.exists(header_path):
self.header_content = open(header_path, 'r').read() self.header_content = open(header_path, 'r').read()
if self.header_content is not None: if self.header_content is not None:
self.header_content = preprocess_macro( self.header_content = self.preprocess_macro(self.header_content)
self.header_content, self.processed_file_path
)
kernel_signature = search_pattern( kernel_signature = search_pattern(
target_kernel_signature_pattern, self.header_content target_kernel_signature_pattern, self.header_content
) )
...@@ -185,6 +165,21 @@ class KernelSignatureSearcher: ...@@ -185,6 +165,21 @@ class KernelSignatureSearcher:
return kernel_signature return kernel_signature
return None return None
def preprocess_macro(self, file_content):
if file_content is None:
return file_content
# comment out external macro
file_content = re.sub(r'#(include|pragma)', r'// \g<0>', file_content)
with open(self.tmp_file_path, "w") as f:
f.write(file_content)
# expand macro and correct format
subprocess.run(
['g++', '-E', self.tmp_file_path, '-o', self.processed_file_path]
)
subprocess.run(['clang-format', '-i', self.processed_file_path])
file_content = open(self.processed_file_path, "r").read()
return file_content
def get_kernel_signatures(): def get_kernel_signatures():
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册