提交 5b288467 编写于 作者: Q quyongxiu1

add cli and report

fix review pros

add readme
上级 98516a84
### Introduction
MindConverter is a tool that converting PyTorch scripts to MindSpore scripts. With minial manual editing and the guidance from conversion reports, users may easily migrate their model from PyTorch framework to MindSpore.
### System Requirements
* PyTorch v1.5.0
* MindSpore v0.2.0
### Installation
This tool is part of MindInsight and accessible to users after installing MindInsight, no extra installation is needed.
### Commandline Usage
Set the model scripts directory as the PYTHONPATH environment variable first:
```buildoutcfg
export PYTHONPATH=<model scripts dir>
```
mindconverter commandline usage:
```buildoutcfg
mindconverter [-h] [--version] --in_file IN_FILE [--output OUTPUT]
[--report REPORT]
MindConverter CLI entry point (version: 0.2.0)
optional arguments:
-h, --help show this help message and exit
--version show program's version number and exit
--in_file IN_FILE Specify path for script file.
--output OUTPUT Specify path for converted script file directory. Default
is output directory in the current working directory.
--report REPORT Specify report directory. Default is the current working
directorys
```
Usage example:
```buildoutcfg
export PYTHONPATH=~/my_pt_proj/models
mindconverter --in_file lenet.py
```
Since the conversion is not 100% flawless, we encourage users to checkout the reports when fixing issues of the converted scripts.
### Unsupported Situation #1
Classes and functions that can't be converted:
* The use of shape, ndim and dtype member of torch.Tensor.
* torch.nn.AdaptiveXXXPoolXd and torch.nn.functional.adaptive_XXX_poolXd()
* torch.nn.functional.Dropout
* torch.unsqueeze() and torch.Tensor.unsqueeze()
* torch.chunk() and torch.Tensor.chunk()
### Unsupported Situation #2
Subclassing from the subclasses of nn.Module
e.g. (code snip from torchvision,models.mobilenet)
```python
from torch import nn
class ConvBNReLU(nn.Sequential):
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
padding = (kernel_size - 1) // 2
super(ConvBNReLU, self).__init__(
nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
nn.BatchNorm2d(out_planes),
nn.ReLU6(inplace=True)
)
```
### Unsupported Situation #3
Unconventional import naming
e.g.
```python
import torch.nn as mm
```
\ No newline at end of file
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Command module."""
import os
import sys
import argparse
import mindinsight
from mindinsight.mindconverter.converter import main
class FileDirAction(argparse.Action):
"""File directory action class definition."""
@staticmethod
def check_path(parser, values, option_string=None):
"""
Check argument for file path.
Args:
parser (ArgumentParser): Passed-in argument parser.
values (object): Argument values with type depending on argument definition.
option_string (str): Optional string for specific argument name. Default: None.
"""
outfile = values
if outfile.startswith('~'):
outfile = os.path.realpath(os.path.expanduser(outfile))
if not outfile.startswith('/'):
outfile = os.path.realpath(os.path.join(os.getcwd(), outfile))
if os.path.exists(outfile) and not os.access(outfile, os.R_OK):
parser.error(f'{option_string} {outfile} not accessible')
return outfile
def __call__(self, parser, namespace, values, option_string=None):
"""
Inherited __call__ method from argparse.Action.
Args:
parser (ArgumentParser): Passed-in argument parser.
namespace (Namespace): Namespace object to hold arguments.
values (object): Argument values with type depending on argument definition.
option_string (str): Optional string for specific argument name. Default: None.
"""
outfile_dir = self.check_path(parser, values, option_string)
if os.path.isfile(outfile_dir):
parser.error(f'{option_string} {outfile_dir} is a file')
setattr(namespace, self.dest, outfile_dir)
class OutputDirAction(argparse.Action):
"""File directory action class definition."""
def __call__(self, parser, namespace, values, option_string=None):
"""
Inherited __call__ method from argparse.Action.
Args:
parser (ArgumentParser): Passed-in argument parser.
namespace (Namespace): Namespace object to hold arguments.
values (object): Argument values with type depending on argument definition.
option_string (str): Optional string for specific argument name. Default: None.
"""
output = values
if output.startswith('~'):
output = os.path.realpath(os.path.expanduser(output))
if not output.startswith('/'):
output = os.path.realpath(os.path.join(os.getcwd(), output))
if os.path.exists(output):
if not os.access(output, os.R_OK):
parser.error(f'{option_string} {output} not accessible')
if os.path.isfile(output):
parser.error(f'{option_string} {output} is a file')
setattr(namespace, self.dest, output)
class InFileAction(argparse.Action):
"""Input File action class definition."""
def __call__(self, parser, namespace, values, option_string=None):
"""
Inherited __call__ method from argparse.Action.
Args:
parser (ArgumentParser): Passed-in argument parser.
namespace (Namespace): Namespace object to hold arguments.
values (object): Argument values with type depending on argument definition.
option_string (str): Optional string for specific argument name. Default: None.
"""
outfile_dir = FileDirAction.check_path(parser, values, option_string)
if not os.path.exists(outfile_dir):
parser.error(f'{option_string} {outfile_dir} not exists')
if not os.path.isfile(outfile_dir):
parser.error(f'{option_string} {outfile_dir} is not a file')
setattr(namespace, self.dest, outfile_dir)
class LogFileAction(argparse.Action):
"""Log file action class definition."""
def __call__(self, parser, namespace, values, option_string=None):
"""
Inherited __call__ method from FileDirAction.
Args:
parser (ArgumentParser): Passed-in argument parser.
namespace (Namespace): Namespace object to hold arguments.
values (object): Argument values with type depending on argument definition.
option_string (str): Optional string for specific argument name. Default: None.
"""
outfile_dir = FileDirAction.check_path(parser, values, option_string)
if os.path.exists(outfile_dir) and not os.path.isdir(outfile_dir):
parser.error(f'{option_string} {outfile_dir} is not a directory')
setattr(namespace, self.dest, outfile_dir)
def cli_entry():
"""Entry point for mindconverter CLI."""
permissions = os.R_OK | os.W_OK | os.X_OK
os.umask(permissions << 3 | permissions)
parser = argparse.ArgumentParser(
prog='mindconverter',
description='MindConverter CLI entry point (version: {})'.format(mindinsight.__version__))
parser.add_argument(
'--version',
action='version',
version='%(prog)s ({})'.format(mindinsight.__version__))
parser.add_argument(
'--in_file',
type=str,
action=InFileAction,
required=True,
help="""
Specify path for script file.
""")
parser.add_argument(
'--output',
type=str,
action=OutputDirAction,
default=os.path.join(os.getcwd(), 'output'),
help="""
Specify path for converted script file directory.
Default is output directory in the current working directory.
""")
parser.add_argument(
'--report',
type=str,
action=LogFileAction,
default=os.getcwd(),
help="""
Specify report directory. Default is the current working directory.
""")
argv = sys.argv[1:]
if not argv:
argv = ['-h']
args = parser.parse_args(argv)
else:
args = parser.parse_args()
mode = permissions << 6
os.makedirs(args.output, mode=mode, exist_ok=True)
os.makedirs(args.report, mode=mode, exist_ok=True)
_run(args.in_file, args.output, '', args.report)
def _run(in_files, out_dir, in_module, report):
"""
Run converter command.
Args:
in_files (str): The file path or directory to convert.
out_dir (str): The output directory to save converted file.
in_module (str): The module name to convert.
report (str): The report file path.
"""
files_config = {
'root_path': in_files if in_files else '',
'in_files': [],
'outfile_dir': out_dir,
'report_dir': report,
'in_module': in_module
}
if os.path.isfile(in_files):
files_config['root_path'] = os.path.dirname(in_files)
files_config['in_files'] = [in_files]
else:
for root_dir, _, files in os.walk(in_files):
for file in files:
files_config['in_files'].append(os.path.join(root_dir, file))
main(files_config)
......@@ -15,4 +15,4 @@
"""Create a logger."""
from mindinsight.utils.log import setup_logger
logger = setup_logger("mindconverter", "mindconverter")
logger = setup_logger("mindconverter", "mindconverter", console=False)
......@@ -222,6 +222,7 @@ class MappingHelper:
def convert(self, call_name_pt: str, args_str_pt: str):
"""
Convert code sentence to MindSpore code sentence.
Args:
call_name_pt (str): str of the call function, etc.
args_str_pt (str): str of args for function, which starts with '(' and end with ')'.
......@@ -336,11 +337,12 @@ def gen_explicit_map_one_delta(params_pt, args_pt, k_ms, k_pt):
def load_json_file(file_path):
"""
Load data from given json file path.
Args:
file_path (str): The file to load json data from.
Returns:
list, the list data stored in file_path.
list(str), the list data stored in file_path.
"""
with open(file_path, 'r', encoding='utf-8') as file:
info = json.loads(file.read())
......
......@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""main module"""
import inspect
"""converter module"""
import copy
import importlib
import inspect
import os
import stat
......@@ -24,12 +24,23 @@ from mindinsight.mindconverter.config import NN_LIST
from mindinsight.mindconverter.config import ALL_TORCH_APIS
from mindinsight.mindconverter.config import ALL_2P_LIST
from mindinsight.mindconverter.config import UNSUPPORTED_WARN_INFOS
from mindinsight.mindconverter.config import ALL_UNSUPPORTED
from mindinsight.mindconverter.common.log import logger
from mindinsight.mindconverter.forward_call import ForwardCall
class Converter:
"""Convert class"""
convert_info = ''
flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL
modes = stat.S_IWUSR | stat.S_IRUSR
def is_local_defined(obj, member):
@staticmethod
def is_local_defined(obj, member):
"""
Check if obj and member are both defined in the same source file.
Args:
obj (Union[object, module]): A module or a class.
member (func): A function of obj.
......@@ -40,10 +51,11 @@ def is_local_defined(obj, member):
srcfile = inspect.getsourcefile(obj)
return inspect.getsourcefile(member) == srcfile
def is_valid_module(obj, member):
@classmethod
def is_valid_module(cls, obj, member):
"""
Check if obj and member defined in same source file and member is inherited from torch.nn.Module.
Args:
obj (Union[object, module]): A module or a class.
member (func): A function.
......@@ -51,12 +63,21 @@ def is_valid_module(obj, member):
Returns:
bool, True or False.
"""
return inspect.isclass(member) and (member.__base__.__name__ == 'Module') and is_local_defined(obj, member)
def is_valid_function(obj, member):
if inspect.isclass(member):
is_subclass = member.__base__.__name__ in ['Module',
'Sequential',
'ModuleList',
'ModuleDict',
'ParameterList',
'ParameterDict']
return is_subclass and cls.is_local_defined(obj, member)
return False
@classmethod
def is_valid_function(cls, obj, member):
"""
Check if member is function and defined in the file same as obj.
Args:
obj (Union[object, module]: The obj.
member (func): The func.
......@@ -64,15 +85,16 @@ def is_valid_function(obj, member):
Returns:
bool, True or False.
"""
return inspect.isfunction(member) and is_local_defined(obj, member)
return inspect.isfunction(member) and cls.is_local_defined(obj, member)
def find_left_parentheses(string, right):
@staticmethod
def find_left_parentheses(string, right):
"""
Find index of the first left parenthesis.
Args:
string (str): A line of code.
right (int): Max index of string, same as `len(string) -1`.
right (int): The right index for string to find from.
Returns:
int, index of the first parenthesis.
......@@ -92,10 +114,11 @@ def find_left_parentheses(string, right):
return i
raise ValueError("{} should contain ()".format(string))
def find_right_parentheses(string, left):
@staticmethod
def find_right_parentheses(string, left):
"""
Find first index of right parenthesis which make all left parenthesis make sense.
Args:
string (str): A line of code.
left (int): Start index of string to find from.
......@@ -116,19 +139,19 @@ def find_right_parentheses(string, left):
return i
raise ValueError("{} should contain ()".format(string))
def get_call_name(code, end):
@staticmethod
def get_call_name(code, end):
"""
Traverse code in a reversed function from index end and get the call name and start index of the call name, if call
name not found, return a null character string and -1
Traverse code in a reversed function from index end and get the call name and start index of the call name,
if call name not found, return a null character string and -1
Args:
code (str): The str of code to find from.
end (int): Start index to find.
Returns:
str, founded api name if found, else a null character string.
int, start index of founded api name, -1 if api name not found
tuple(str, int), one is founded api name if found, else a null character string, the other is start index
of founded api name, -1 if api name not found
"""
stack = []
for i in range(end - 1, -1, -1):
......@@ -145,11 +168,10 @@ def get_call_name(code, end):
return code[i + 1:end], i + 1
return "", -1
def convert_api(code, start, api_name=""):
def convert_api(self, code, start, api_name=""):
"""
Convert api_name in code to MindSpore api with start as a start index, if api_name is a python api, code will not
convert.
Convert api_name in code to MindSpore api with start as a start index, if api_name is a python api,
code will not convert.
Args:
code (str): The str code to convert.
......@@ -159,11 +181,10 @@ def convert_api(code, start, api_name=""):
Returns:
str, the converted code.
int, index of converted api_name in code.
"""
# handle format like .shape(
if api_name.startswith('.'):
call_name, new_start = get_call_name(code, start)
call_name, new_start = self.get_call_name(code, start)
if start == -1 or call_name == "self":
return code, start + 1
else:
......@@ -174,7 +195,7 @@ def convert_api(code, start, api_name=""):
left = code.find("(", start)
if left == -1:
raise ValueError('"(" not found, {} should work with "("'.format(call_name))
right = find_right_parentheses(code, left)
right = self.find_right_parentheses(code, left)
end = right
expr = code[start:end + 1]
args_str = code[left:right + 1]
......@@ -187,12 +208,14 @@ def convert_api(code, start, api_name=""):
code = code[:new_start] + new_expr + code[end + 1:next_newline] + ("\n" * fill_num) + code[next_newline:]
else:
code = code[:new_start] + new_expr + ")" + ("\n" * fill_num) + code[end + 2:]
return code, start + len(map_helper.ms_api.name)
return code, start + len(map_helper.ms_api.name)
def find_api(code, i, is_forward):
@staticmethod
def find_api(code, i, is_forward):
"""
Find api name from code with a start index i, check api name ok with a is_forward condition.
Args:
code (str): The code from which to find api name.
i (int): The start index to find.
......@@ -212,10 +235,10 @@ def find_api(code, i, is_forward):
return api_name
return ""
def convert_function(fun_name, fun, is_forward):
def convert_function(self, fun_name, fun, is_forward):
"""
Convert a PyTorch function into MindSpore function.
Args:
fun_name (str): The str of function name.
fun (func): The function to convert.
......@@ -232,20 +255,24 @@ def convert_function(fun_name, fun, is_forward):
i = 0
while i < len(code):
api_name = find_api(code, i, is_forward)
api_name = self.find_api(code, i, is_forward)
if api_name:
line_no1 = line_no + code[:i].count('\n')
if api_name in ALL_MAPPING:
logger.info("Line %3d start converting API: %s", line_no1, api_name)
code, i = convert_api(code, i, api_name)
code, i = self.convert_api(code, i, api_name)
self.convert_info += "[Convert][Line{:3d}] {} is converted.\n".format(line_no1, api_name)
continue
if api_name in ALL_UNSUPPORTED:
warn_info = ". " + UNSUPPORTED_WARN_INFOS[api_name] if api_name in UNSUPPORTED_WARN_INFOS else ""
logger.warning("Line %3d: found unsupported API: %s%s", line_no1, api_name, warn_info)
self.convert_info += "[Unconvert][Line{:3d}] {} didn't convert{}\n".format(line_no1,
api_name, warn_info)
i += 1
return {code_saved: code} if code_saved != code else {}
def judge_forward(name, forward_list):
@staticmethod
def judge_forward(name, forward_list):
"""
Check if function is a forward function.
......@@ -261,10 +288,10 @@ def judge_forward(name, forward_list):
logger.debug("%s is a forward function", name)
return is_forward
def convert_module(module_name, module, forward_list):
def convert_module(self, module_name, module, forward_list):
"""
Convert a PyTorch module code into MindSpore module code.
Args:
module_name (str): The module's name.
module (module): The module to convert.
......@@ -278,15 +305,15 @@ def convert_module(module_name, module, forward_list):
mapped = {}
for name, member in inspect.getmembers(module):
if is_valid_function(module, member):
is_forward = judge_forward("{}.{}".format(module_name, name), forward_list)
mapped.update(convert_function(name, member, is_forward))
if self.is_valid_function(module, member):
is_forward = self.judge_forward("{}.{}".format(module_name, name), forward_list)
mapped.update(self.convert_function(name, member, is_forward))
return mapped
def get_mapping(import_mod, forward_list):
def get_mapping(self, import_mod, forward_list):
"""
Convert code of a module and get mapping of old code and convert code.
Args:
import_mod (module): The module to convert.
forward_list (set): A set of forward function.
......@@ -297,36 +324,40 @@ def get_mapping(import_mod, forward_list):
mapping = {}
tasks = []
for name, member in inspect.getmembers(import_mod):
if is_valid_module(import_mod, member):
if self.is_valid_module(import_mod, member):
_, line_no = inspect.getsourcelines(member)
tasks.append((line_no, convert_module, (name, member, forward_list)))
elif is_valid_function(import_mod, member):
tasks.append((line_no, self.convert_module, (name, member, forward_list)))
elif self.is_valid_function(import_mod, member):
_, line_no = inspect.getsourcelines(member)
is_forward = judge_forward("{}.{}".format(import_mod, name), forward_list)
tasks.append((line_no, convert_function, (name, member, is_forward)))
is_forward = self.judge_forward("{}.{}".format(import_mod, name), forward_list)
tasks.append((line_no, self.convert_function, (name, member, is_forward)))
tasks.sort()
for _, convert_fun, args in tasks:
mapping.update(convert_fun(*args))
return mapping
def convert(import_name, nn_module):
def convert(self, import_name, output_dir, report_dir):
"""
The entrance for convert a module's code, code converted will be write to file called out.py.
Convert a module's code, code converted will be save in output_dir, and a report will be save in report_dir.
Args:
import_name (str): The module from which to import the module to convert.
nn_module (str): Name of the module to convert.
output_dir (str): The path to save converted file.
report_dir (str): The path to save report file.
"""
logger.info("Start converting %s.%s", import_name, nn_module)
logger.info("Start converting %s", import_name)
self.convert_info += '[Start Convert]\nThe module is {}\n'.format(import_name)
import_mod = importlib.import_module(import_name)
forward_list = set()
srcfile = inspect.getsourcefile(import_mod)
logger.info("Script file is %s", srcfile)
forward_list = set(ForwardCall(srcfile).calls)
logger.debug("Forward_list: %s", forward_list)
# replace python function under nn.Modlue
mapping = get_mapping(import_mod, forward_list)
# replace python function under nn.Module
mapping = self.get_mapping(import_mod, forward_list)
code = inspect.getsource(import_mod)
for key, value in mapping.items():
......@@ -335,6 +366,12 @@ def convert(import_name, nn_module):
code = 'import mindspore.ops.operations as P\n' + code
code = 'import mindspore.nn as nn\n' + code
code = 'import mindspore\n' + code
self.convert_info += '||[Import Add] Add follow import sentences:\n'
self.convert_info += 'import mindspore.ops.operations as P\n'
self.convert_info += 'import mindspore.nn as nn\n'
self.convert_info += 'import mindspore\n\n'
code = code.replace('import torch', '# import torch')
code = code.replace('from torch', '# from torch')
code = code.replace('(nn.Module):', '(nn.Cell):')
......@@ -343,13 +380,82 @@ def convert(import_name, nn_module):
code = code.replace('(nn.Sequential)', '(nn.SequentialCell)')
code = code.replace('nn.init.', 'pass # nn.init.')
flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL
modes = stat.S_IWUSR | stat.S_IRUSR
with os.fdopen(os.open('out.py', flags, modes), 'w') as file:
self.convert_info += '||[Import Annotated] Annotated follow import sentences:\n'
self.convert_info += 'import sentence on torch as follows are annotated:\n'
self.convert_info += 'import torch\n'
self.convert_info += 'from torch ...\n'
self.convert_info += '||[Explicit Convert] Module or function are explicitly converted as follows:\n'
self.convert_info += '[nn.Module] is converted to [nn.Cell]\n'
self.convert_info += '[forward] is converted to [construct]\n'
self.convert_info += '[nn.Linear] is converted to [nn.Dense]\n'
self.convert_info += '[nn.Sequential] is converted to [nn.SequentialCell]\n'
self.convert_info += '[nn.init] is not converted and annotated\n'
self.convert_info += '[Convert over]'
dest_file = os.path.join(output_dir, os.path.basename(srcfile))
with os.fdopen(os.open(dest_file, self.flags, self.modes), 'w') as file:
file.write(code)
logger.info("Convert success. Result is wrote to out.py\n")
logger.info("Convert success. Result is wrote to %s.", dest_file)
dest_report_file = os.path.join(report_dir,
'_'.join(os.path.basename(srcfile).split('.')[:-1]) + '_report.txt')
with os.fdopen(os.open(dest_report_file, self.flags, self.modes), 'a') as file:
file.write(self.convert_info)
logger.info("Convert report is saved in %s", dest_report_file)
if __name__ == '__main__':
def _get_name_ext(file):
"""
Split a file name in name and extension.
Args:
file (str): Full file path.
Returns:
tuple (str, str), name and extension.
"""
_, name = os.path.split(file)
return os.path.splitext(name)
convert('torchvision.models.resnet', 'resnet18')
def _path_split(file):
"""
Split a path in head and tail.
Args:
file (str): The file path.
Returns:
list[str], list of file tail
"""
file_dir, name = os.path.split(file)
if file_dir:
sep = file[len(file_dir)-1]
if file_dir.startswith(sep):
return file.split(sep)[1:]
return file.split(sep)
return [name]
def main(files_config):
"""
The entrance for converter, script files will be converted.
Args:
files_config (dict): The config of files which to convert.
"""
convert_ins = Converter()
root_path = files_config['root_path']
in_files = files_config['in_files']
for in_file in in_files:
in_file_split = _path_split(in_file[len(root_path):])
in_file_split[-1], _ = _get_name_ext(in_file_split[-1])
module_name = '.'.join(in_file_split)
convert_ins.convert(module_name, files_config['outfile_dir'], files_config['report_dir'])
in_module = files_config['in_module']
if in_module:
convert_ins.convert(in_module, files_config['outfile_dir'], files_config['report_dir'])
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Find out forward functions of script file"""
import ast
import os
class ForwardCall(ast.NodeVisitor):
"""
AST visitor that processes forward calls.
Find the sub functions called by the forward function in the script file.
"""
def __init__(self, filename):
self.filename = filename
self.module_name = os.path.basename(filename).replace('.py', '')
self.name_stack = []
self.forward_stack = []
self.calls = []
self.process()
def process(self):
"""Parse the python source file to find the forward functions."""
with open(self.filename, 'rt', encoding='utf-8') as file:
content = file.read()
self.visit(ast.parse(content, self.filename))
def get_current_namespace(self):
"""Get the namespace when visit the AST node"""
namespace = '.'.join(self.name_stack)
return namespace
@classmethod
def get_ast_node_name(cls, node):
"""Get AST node name."""
if isinstance(node, ast.Attribute):
return f'{cls.get_ast_node_name(node.value)}.{node.attr}'
if isinstance(node, ast.Name):
return node.id
return node
def visit_ClassDef(self, node):
"""Callback function when visit AST tree"""
self.name_stack.append(node.name)
self.generic_visit(node)
self.name_stack.pop()
def visit_FunctionDef(self, node):
"""Callback function when visit AST tree"""
func_name = f'{self.get_current_namespace()}.{node.name}'
is_in_chain = func_name in self.calls or node.name == 'forward'
if is_in_chain:
self.forward_stack.append(func_name)
if node.name == 'forward':
self.calls.append(func_name)
self.generic_visit(node)
if is_in_chain:
self.forward_stack.pop()
def visit_Call(self, node):
"""Callback function when visit AST tree"""
for arg in node.args:
self.visit(arg)
for kw in node.keywords:
self.visit(kw.value)
func_name = self.get_ast_node_name(node.func)
if isinstance(node.func, ast.Name):
if func_name not in ['super', 'str', 'repr']:
if self.forward_stack:
self.calls.append(func_name)
self.visit(node.func)
else:
if self.forward_stack:
if 'self' in func_name:
self.calls.append(f'{self.get_current_namespace()}.{func_name.split(".")[-1]}')
else:
self.calls.append(func_name)
self.visit(node.func)
......@@ -194,6 +194,7 @@ if __name__ == '__main__':
entry_points={
'console_scripts': [
'mindinsight=mindinsight.utils.command:main',
'mindconverter=mindinsight.mindconverter.cli:cli_entry',
],
},
python_requires='>=3.7',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册