diff --git a/mindinsight/mindconverter/ast_edits.py b/mindinsight/mindconverter/ast_edits.py new file mode 100644 index 0000000000000000000000000000000000000000..b96cc350c385551b008cb5571afbaac7a2d847cd --- /dev/null +++ b/mindinsight/mindconverter/ast_edits.py @@ -0,0 +1,579 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless REQUIRED by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Convert for Python scripts according API mapping information.""" + +import ast +import logging +import re +from enum import Enum + +import pasta +from pasta.augment import import_utils + +from mindinsight.mindconverter.code_analysis import CodeAnalyzer +from mindinsight.mindconverter.code_analysis import APIAnalysisSpec +from mindinsight.mindconverter.config import ALL_MAPPING +from mindinsight.mindconverter.config import NN_LIST +from mindinsight.mindconverter.config import ALL_TORCH_APIS +from mindinsight.mindconverter.config import ALL_2P_LIST +from mindinsight.mindconverter.config import UNSUPPORTED_WARN_INFOS +from mindinsight.mindconverter.config import ALL_UNSUPPORTED +from mindinsight.mindconverter.common.log import logger +from mindinsight.mindconverter.common.exceptions import NodeTypeNotSupport +from mindinsight.mindconverter.forward_call import ForwardCall + +LOG_FMT_CONVERT = "[Convert] '%s' is converted to '%s'." +LOG_FMT_NOT_CONVERT = "[UnConvert] '%s' didn't convert. %s" +LOG_FMT_PROMPT_INFO = "[INFO] %s" +LOG_SUGGESTION_MANUAL_CONVERT = "Please manual convert the code, along with the code associated with it." + + +class ApiMatchingEnum(Enum): + """Node edge type enum.""" + NOT_API = 'not an api name' + API_INFER = 'infer api name to map' + API_STANDARD = 'api name in the correct format' + API_FOUND = 'found an api name in api list' + API_MATCHED = 'api is matched to map' + + +class _ConvertReport: + """Report log of converting source code.""" + + def __init__(self, is_stub=False): + self._is_stub = is_stub + self._max_line = 0 + self._log = [] # report log, type is (severity, line, col, msg) + + def _add_log(self, severity, line, col, msg): + """Add log.""" + if self._is_stub: + return + if isinstance(line, int) and isinstance(col, int): + self._log.append((severity, line, col, msg)) + if self._max_line < line: + self._max_line = line + + def info(self, line, col, msg): + """Interface to add infer log""" + self._add_log(logging.INFO, line, col, msg) + + def warning(self, line, col, msg): + """Interface to add warning log""" + self._add_log(logging.WARNING, line, col, msg) + + def get_logs(self): + """Get convert logs""" + logs = [] + # sort rule: line * self._max_line + col + self._log.sort(key=lambda log: log[1] * self._max_line + log[2]) + for log_info in self._log: + log_info = "line %d:%d: %s" % (log_info[1], log_info[2], log_info[3]) + logs.append(log_info) + return logs + + +class _LineColEditVisitor(ast.NodeVisitor): + """ + Update line number and col offset of ast node. + + Use the line and column number of the original code to update + the line and column number of the new code replaced with the original code. + """ + + class _NodeInfo: + """NodeInfo class definition.""" + def __init__(self, node): + self.node = node + self.call_list = [] # Used to save all ast.Call node in self._node + + def __init__(self): + self._dst_node_info = None + self._src_node_info = None + self._visiting = self._src_node_info # Used to point to the visiting node + + def update(self, replace_with_node, src_node): + """Update the line and column number of the new code replaced with the original code.""" + replace_with_node.lineno = src_node.lineno + replace_with_node.col_offset = src_node.col_offset + self._dst_node_info = self._NodeInfo(replace_with_node) + self._src_node_info = self._NodeInfo(src_node) + self._visiting = self._src_node_info + self.visit(self._visiting.node) + + self._visiting = self._dst_node_info + self.visit(self._visiting.node) + + self._update_line_col() + + def visit_Call(self, node): + """Callback function when visit AST tree""" + self._visiting.call_list.append(node) + self.generic_visit(node) + + def _update_line_col(self): + """Update the line and column number information for all ast.Call node.""" + dst_call_list = list(self._dst_node_info.call_list) + src_call_list = list(self._src_node_info.call_list) + len_diff = len(dst_call_list) - len(src_call_list) + + # After MindSpore api replaces Torch api, more calls are generated. + # For example, out.view() is replaced with P.Reshape()(out). + # out.view() has only one call, but P.Reshape()(out) has two calls. + # To match the replaced calls, the calls of out.view is padded to the same quantity. + if len_diff > 0: + src_call_list = [src_call_list[0]] * len_diff + src_call_list + + for dst_call, src_call in zip(dst_call_list, src_call_list): + dst_call.lineno = src_call.lineno + dst_call.col_offset = src_call.col_offset + + if not dst_call.args: + continue + + # When out.size().view(1, ...) transforms to P.Reshape()(out.size(), 1, ...), + # in this case, the column of parameter out.size() will be bigger than the following parameters. + # To ensure the sequence of parameters, adjust the column of the second parameter. + args = [] + for arg in dst_call.args: + if self._check_arg2update(arg): + args.append(arg) + for arg in args: + arg.lineno = dst_call.lineno + arg.col_offset += dst_call.col_offset + + @staticmethod + def _check_arg2update(arg): + # Only the col_offset of the first line code is re-counted, needs to be corrected. + # When the arg is a function call, its col_offset is handled separately. + if not isinstance(arg, ast.Call) and arg.lineno == 1: + return True + return False + + +class AstEditVisitor(ast.NodeVisitor): + """AST Visitor that process function calls. + + Converts function calls from torch api to MindSpore api using api mapping information. + """ + + def __init__(self): + self._process_log = _ConvertReport() + self._tree = None + self._code_analyzer = None + self._stack = [] # Used to easily access the parent node + self._forward_list = {} + self._is_forward_function = False # Used to allow access the visiting function forward attribute + self._new_call_nodes = [] # Used to save new ast.call nodes + + def process(self, ast_tree): + """ + Convert source code to MindSpore code. + + Args: + ast_tree (AST): The root node of the source code. + """ + self.__init__() + self._tree = ast_tree + self._code_analyzer = CodeAnalyzer() + self._code_analyzer.process(self._tree) + + self._forward_list = ForwardCall(self._tree).calls + # replace python function under nn.Module + self._convert_api() + + # replace external reference statements + self._convert_external_reference() + + def get_logs(self): + """Get conversion report.""" + return self._process_log.get_logs() + + def _convert_cell(self, cell_scope): + """ + Convert a PyTorch Module class into MindSpore Cell class. + + Args: + cell_scope (pasta.base.Scope): The network class definition node inherits from torch.nn.Module. + """ + cell_ast_node = cell_scope.node + line_no = cell_ast_node.lineno + logger.info("Line %3d: start converting nn.Module %s", line_no, self._code_analyzer.get_name(cell_ast_node)) + + class_elements = self._code_analyzer.network_definitions()['cell'] + # step1. update function definition + for func_scope in class_elements.get(cell_scope, []): + self._update_function_def(func_scope) + + # step2. update base name of class + self._update_base_name(cell_scope) + + def _update_base_name(self, class_def_scope): + """ + Update base name of class. + + Args: + class_def_scope (ast.ClassDef): Class definition node. + """ + base_name_mapping = APIAnalysisSpec.base_name_mapping + class_def_node = class_def_scope.node + base_class_nodes = class_def_scope.node.bases + # update base class name + for base_class_node in base_class_nodes: + base_name = base_class_node.attr + if base_name in APIAnalysisSpec.get_network_base_class_names(): + old_code = pasta.dump(base_class_node) + if base_name in base_name_mapping: + new_code = 'nn.' + base_name_mapping[base_class_node.attr] + new_node = pasta.parse(new_code) + pasta.ast_utils.replace_child(class_def_node, base_class_node, new_node) + self._process_log.info(base_class_node.lineno, base_class_node.col_offset, LOG_FMT_CONVERT % + (old_code, new_code)) + else: + self._process_log.info(base_class_node.lineno, base_class_node.col_offset, LOG_FMT_NOT_CONVERT % + (old_code, '')) + + def _update_function_def(self, func_scope): + """ + Convert a PyTorch function into MindSpore function. + + Args: + func_scope (pasta.base.scope.Scope): The node scope of function definition. + """ + is_forward = self._judge_forward(func_scope) + # step1. convert the content of the function. + self._convert_function(func_scope, is_forward) + + # step2. replace function name if name is forward + func_ast_node = func_scope.node + old_func_name = 'forward' + new_func_name = 'construct' + if func_ast_node.name == old_func_name: + func_ast_node.name = new_func_name + self._process_log.info(func_ast_node.lineno, func_ast_node.col_offset, + LOG_FMT_CONVERT % (old_func_name, new_func_name)) + + def _convert_api(self): + """Convert PyTorch api call to MindSpore api call in a function.""" + tasks = [] + + convert_elements = self._code_analyzer.network_definitions() + for func_node_scope in convert_elements.get("functions", []): + is_forward = self._judge_forward(func_node_scope) + tasks.append((self._convert_function, (func_node_scope, is_forward))) + for class_scope in convert_elements.get("cell", []).keys(): + tasks.append((self._convert_cell, (class_scope,))) + + for convert_fun, args in tasks: + convert_fun(*args) + + def _convert_external_reference(self): + """Convert import statements.""" + name_replace = APIAnalysisSpec.import_name_mapping + replace_imports = list(name_replace.values()) + + for ref_info in self._code_analyzer.external_references.values(): + external_ref_info = ref_info['external_ref_info'] + parent_node = ref_info['parent_node'] + if parent_node is None: + continue + code = pasta.dump(parent_node) + if external_ref_info.name in APIAnalysisSpec.get_convertible_external_names(): + external_ref_info = ref_info['external_ref_info'] + if external_ref_info.name in name_replace.keys(): + import_utils.remove_import_alias_node(self._code_analyzer.root_scope, external_ref_info.node) + replace_info = name_replace[external_ref_info.name] + new_ref_name = replace_info[1] + new_external_name = replace_info[0] + if new_ref_name: + new_code = f'import {new_external_name} as {new_ref_name}' + else: + new_code = f'import {new_external_name}' + + self._process_log.info(parent_node.lineno, parent_node.col_offset, LOG_FMT_CONVERT % + (code.strip(), new_code.strip())) + elif external_ref_info.name.startswith('torch.'): + self._process_log.warning(parent_node.lineno, parent_node.col_offset, LOG_FMT_NOT_CONVERT % + (code.strip(), LOG_SUGGESTION_MANUAL_CONVERT)) + else: + pass + + # Insert import in reverse order, display in forward order. + for idx in range(len(replace_imports) - 1, -1, -1): + replace_import = replace_imports[idx] + if replace_import[1]: + self._add_import(name_to_import=replace_import[0], as_name=replace_import[1]) + else: + self._add_import(name_to_import=replace_import[0]) + + def _add_import(self, name_to_import, as_name=None): + """ + Adds an import to the ast tree. + + Args: + name_to_import: (string) The absolute name to import. + as_name: (string) The alias for the import ("import name_to_import as asname") + """ + new_alias = ast.alias(name=name_to_import, asname=as_name) + import_node = ast.Import(names=[new_alias]) + + # Insert the node at the top of the module + self._tree.body.insert(1 if pasta.base.ast_utils.has_docstring(self._tree) else 0, import_node) + + def _convert_function(self, func_scope, is_forward): + """ + Convert a PyTorch function into MindSpore function. + + Args: + func_scope (pasta.base.scope.Scope): The node scope of function definition. + is_forward (boolean): If the function is defined in forward function in nn.Module in torch. + """ + func_ast_node = func_scope.node + line_no = func_ast_node.lineno + logger.info("Line %3d: start converting function %s()", line_no, func_ast_node.name) + + parent = func_scope.parent_scope.node + self._stack.clear() + self._new_call_nodes.clear() + if parent: + self._stack.append(parent) + + self._is_forward_function = is_forward + self.visit(func_scope.node) + + def _judge_forward(self, func_scope): + """ + Check if function is a forward function. + + Args: + func_scope (pasta.base.scope.Scope): The node scope of function definition. + + Returns: + boolean, True or False + """ + is_forward = func_scope.node in self._forward_list.values() + if is_forward: + logger.debug("%s is a forward function", self._code_analyzer.get_name(func_scope)) + return is_forward + + # Overridden to maintain stack information to access parent node + def visit(self, node): + """Visit a ast tree.""" + self._stack.append(node) + super(AstEditVisitor, self).visit(node) + self._stack.pop() + + def _mapping_standard_api_name(self, api_name): + """Get mapping from external reference name to standard external reference name""" + standard_name = api_name + if not self._code_analyzer.is_standard_external_ref: + # key is real ref name, value is standard ref name. + mapping_names = self._mapping_standard_external_ref() + api_name_parts = api_name.split('.') + api_name_parts[0] = mapping_names.get(api_name_parts[0], api_name_parts[0]) + standard_name = '.'.join(api_name_parts) + return standard_name + + def _infer_api_name(self, call_func_node, check_context=True): + """Infer the call name. + + Examples: + 1. nn.Sequential inferred to nn.Sequential + 2. mmm.size inferred to .size if import torch.nn as nn + 3. mmm.size inferred to mmm.size if import torch.nn as mmm + """ + match_case = ApiMatchingEnum.NOT_API + api_name = None + call_name = pasta.dump(call_func_node) + + is_include_sub_call = self._is_include_sub_call(call_func_node) + if is_include_sub_call: + name_attributes = call_name.rsplit('.', 1) + else: + name_attributes = call_name.split('.') + + # rewritten external module name + # e.g., mm.ReLU will be written to nn.ReLU if 'import torch.nn as mm' in script. + if check_context and not self._code_analyzer.is_standard_external_ref: + standard_name = self._mapping_standard_api_name(name_attributes[0]) + else: + standard_name = name_attributes[0] + + if standard_name in ["nn", "F", "torch"]: + match_case = ApiMatchingEnum.API_STANDARD + api_name = call_name + else: + # only infer function for tensor object. + # e.g., api_call_name is out.view, .view is an api name for out which is maybe a tensor object. + # e.g., 'xxxx'.size can be not inferred to .size, because string is not a tensor object. + first_name = standard_name.split('.')[0] + if not re.search(r'\W', first_name) and len(name_attributes) > 1: + api_name = '.' + name_attributes[-1] + match_case = ApiMatchingEnum.API_INFER + return api_name, match_case + + @staticmethod + def _is_include_sub_call(call_func_node): + """"Inspect a sub call in call expression. + + Examples: + 1. nn.functional.relu() return False + 2. nn.functional.relu(out).size() return True. nn.functional.relu(out) is sub call. + 3. nn.functional.relu(out=out.size()).size() return False. out.size() is not sub call of argument. + """ + is_include_call = False + try: + sub_node = call_func_node + while sub_node and not isinstance(sub_node, ast.Call): + sub_node = sub_node.value + if isinstance(sub_node, ast.Call): + is_include_call = True + except AttributeError: + is_include_call = False + return is_include_call + + def match_api(self, call_func_node, is_forward): + """ + Check api name to convert, check api name ok with a is_forward condition. + + Args: + call_func_node (ast.Attribute): The call.func node. + is_forward (bool): whether api belong to forward. + + Returns: + str, the standard api name used to match. + ApiMappingEnum, the match result. + """ + api_name, match_case = self._infer_api_name(call_func_node) + api_call_name = pasta.dump(call_func_node) + is_tensor_obj_call = False + if api_name != api_call_name: + is_tensor_obj_call = True + + standard_api_call_name = api_name + + # rewritten external module name + # e.g., mm.ReLU will be written to nn.ReLU if 'import torch.nn as mm' in script. + if not is_tensor_obj_call and not self._code_analyzer.is_standard_external_ref: + standard_api_call_name = self._mapping_standard_api_name(api_name) + + if standard_api_call_name in ALL_TORCH_APIS: + match_case = ApiMatchingEnum.API_FOUND + if (not is_forward and standard_api_call_name in NN_LIST) or \ + (is_forward and standard_api_call_name in ALL_2P_LIST): + match_case = ApiMatchingEnum.API_MATCHED + + return standard_api_call_name, match_case + + def mapping_api(self, call_node, check_context=True): + """ + Convert api_name in code to MindSpore api, if api_name is a python api, code will not convert. + + If do not check context of the script, the code represented by the node must be written in the standard way. + + Args: + call_node (ast.Call): The ast node to convert. + check_context (boolean): If True, the code context will be checked. Default is True. + + Returns: + str, the converted code. + """ + if not isinstance(call_node, ast.Call): + raise NodeTypeNotSupport("It is not ast.Call node.") + code = pasta.dump(call_node) + api_call_name = pasta.dump(call_node.func) + if api_call_name.startswith('self.'): + return code + + # find full api expected to be converted. eg:expr="nn.Conv2d(1,2,3)" args_str="(1,2,3)" + args_str = code[len(api_call_name):].strip() + + try: + api_name, _ = self._infer_api_name(call_node.func, check_context) + standard_api_call_name = api_call_name + if api_name != api_call_name: + # api name .view inferred from out.view, split tensor object name is out + tensor_obj_name = api_call_name[:-len(api_name)] + map_helper = ALL_MAPPING[api_name] + new_code = map_helper.convert(tensor_obj_name, args_str) + else: + # change to external ref name + # e.g., mm.ReLU will be changed to nn.ReLU if 'import torch.nn as mm' in script. + if check_context and not self._code_analyzer.is_standard_external_ref: + standard_api_call_name = self._mapping_standard_api_name(api_name) + + map_helper = ALL_MAPPING[standard_api_call_name] + new_code = map_helper.convert(standard_api_call_name, args_str) + except KeyError: + return code + + return new_code + + def visit_Call(self, node): + """Callback function when visit AST tree""" + code = pasta.dump(node) + api_name = pasta.dump(node.func) + + # parent node first call is equal to this node, skip when parent node is replaced. + for parent_node in self._stack[:-1]: + if parent_node in self._new_call_nodes and pasta.dump(parent_node).startswith(api_name): + return + parent = self._stack[-2] + new_node = None + matched_api_name, match_case = self.match_api(node.func, self._is_forward_function) + if match_case in [ApiMatchingEnum.API_INFER, ApiMatchingEnum.API_MATCHED]: + if matched_api_name in ALL_MAPPING: + logger.info("Line %3d start converting API: %s", node.lineno, api_name) + new_code = self.mapping_api(node) + if new_code != code: + new_node = pasta.parse(new_code).body[0].value + # find the first call name + new_api_name = new_code[:new_code.find('(')] + self._process_log.info(node.lineno, node.col_offset, LOG_FMT_CONVERT % (api_name, new_api_name)) + if matched_api_name in ALL_UNSUPPORTED: + warn_info = UNSUPPORTED_WARN_INFOS.get(api_name, '') + logger.warning("Line %3d: found unsupported API: %s%s", node.lineno, api_name, warn_info) + self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, warn_info)) + + elif match_case in [ApiMatchingEnum.API_STANDARD, ApiMatchingEnum.API_FOUND]: + self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, '')) + else: + pass + + if parent and new_node: + update_line_col = _LineColEditVisitor() + update_line_col.update(new_node, node) + pasta.ast_utils.replace_child(parent, node, new_node) + self._new_call_nodes.append(new_node) + + node = new_node + self._stack[-1] = node + try: + self.generic_visit(node) + except Exception: + logger.error('original code:%s, new code:%s', code, new_code, exc_info=True) + raise + + def _mapping_standard_external_ref(self): + """Obtain the mapping dict of mapping the external references to standard external references.""" + renames = {} + external_refs = self._code_analyzer.external_references + for ref_name, ref_info in external_refs.items(): + external_ref_info = ref_info['external_ref_info'] + if ref_name != 'nn' and external_ref_info.name == 'torch.nn': + renames[ref_name] = 'nn' + elif ref_name != 'F' and external_ref_info.name == 'torch.nn.functional': + renames[ref_name] = 'F' + return renames diff --git a/mindinsight/mindconverter/cli.py b/mindinsight/mindconverter/cli.py index 7e0a7d67b55a1cbba8f59c2a100f462ee110c56a..9c49d65d81c447475c81f5eda42bae06aecc5000 100644 --- a/mindinsight/mindconverter/cli.py +++ b/mindinsight/mindconverter/cli.py @@ -186,25 +186,23 @@ def cli_entry(): mode = permissions << 6 os.makedirs(args.output, mode=mode, exist_ok=True) os.makedirs(args.report, mode=mode, exist_ok=True) - _run(args.in_file, args.output, '', args.report) + _run(args.in_file, args.output, args.report) -def _run(in_files, out_dir, in_module, report): +def _run(in_files, out_dir, report): """ Run converter command. Args: in_files (str): The file path or directory to convert. out_dir (str): The output directory to save converted file. - in_module (str): The module name to convert. report (str): The report file path. """ files_config = { 'root_path': in_files if in_files else '', 'in_files': [], 'outfile_dir': out_dir, - 'report_dir': report, - 'in_module': in_module + 'report_dir': report } if os.path.isfile(in_files): files_config['root_path'] = os.path.dirname(in_files) diff --git a/mindinsight/mindconverter/code_analysis.py b/mindinsight/mindconverter/code_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..adbbbce6924b3a354f06e097847e27d4dc7b6309 --- /dev/null +++ b/mindinsight/mindconverter/code_analysis.py @@ -0,0 +1,399 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless REQUIRED by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""code analysis module""" +import ast + +import pasta +from pasta.base import scope + +from mindinsight.mindconverter.common.exceptions import ScriptNotSupport + + +class APIAnalysisSpec: + """API analysis specifications""" + + import_name_mapping = {'torch': ['mindspore', None], + 'torch.nn': ['mindspore.nn', 'nn'], + 'torch.nn.functional': ['mindspore.ops.operations', 'P']} + + base_name_mapping = {'Module': 'Cell', + 'Sequential': 'SequentialCell' + } + + @classmethod + def get_convertible_external_names(cls): + """ + Obtain the convertible external names. + + The external name is the full dotted name being referenced. + """ + return cls.import_name_mapping.keys() + + @staticmethod + def get_network_base_class_names(): + """Obtain the base names which network class base from""" + return ['Module', + 'Sequential', + 'ModuleList', + 'ModuleDict', + 'ParameterList', + 'ParameterDict'] + + @staticmethod + def check_external_alias_ref(ref_name, external_name): + """ + Check 'import as' is standard. + + Standard references are follow: + import torch.nn as nn + import torch.nn.functional as F + + Args: + ref_name (str): The name that refers to the external_name. + external_name (str): The full dotted name being referenced. For examples: + 1. 'import torch.nn as nn', torch.nn is external_name, nn is ref_name. + 2. 'from torch import nn as mm, torch.nn is external_name, mm is ref_name which is not a standard name. + + Returns: + boolean, True if ref_name is standard else False. + """ + if ref_name != 'nn' and external_name == 'torch.nn': + is_standard = False + elif ref_name != 'F' and external_name == 'torch.nn.functional': + is_standard = False + else: + is_standard = True + + return is_standard + + +class CodeAnalyzer(ast.NodeVisitor): + """Code analyzer that analyzes PyTorch python script by AST Visitor. + + CodeAnalyzer find the codes that need to be converted to MindSpore, + and provides the attributes related to the codes. + """ + + def __init__(self): + self._stack = [] # Used to easily access the parent node + self._external_references = {} + self._is_standard_external_ref = True + self._root_scope = None + # Used to save functions that need to be converted, value type is pasta.base.scope.Scope + self._network_functions = [] + + # Used to easily trace the function node + self._functions_stack = [] + + # key type is pasta.base.scope.Scope, value type is list + self._network_classes = {} + + @property + def root_scope(self): + """The root scope of the python script code.""" + return self._root_scope + + @property + def is_standard_external_ref(self): + """Obtain whether the result is a standard external reference.""" + return self._is_standard_external_ref + + @property + def external_references(self): + """Obtain all external references in the analyzed code.""" + return self._external_references + + def network_definitions(self): + """Obtain the network definitions which need to be converted.""" + return {"functions": self._network_functions, + "cell": self._network_classes} + + def process(self, ast_tree): + """ + Start to analyze the code. + + Args: + ast_tree (AST): The root node of the source code. + """ + self.__init__() + self._root_scope = scope.analyze(ast_tree) + self._pre_process() + self.visit(ast_tree) + if not self._network_classes: + msg = "model definition not be found." + raise ScriptNotSupport(msg) + + @staticmethod + def _check_external_standard(external_refs): + """Check whether all external references are standard.""" + is_standard = True + for external_name, external_ref_info in external_refs.items(): + is_standard = APIAnalysisSpec.check_external_alias_ref(external_name, external_ref_info.name) + if not is_standard: + break + return is_standard + + def _is_base_from_cell(self, node): + """ + Check whether the node bases from cell classes which are defined in APIAnalysisSpec. + + Args: + node (ast.ClassDef): The node which is a class definition. + + Returns: + boolean, True if the check result is Passed else False. + """ + if self._is_ref_convertible_imports(node): + whole_name = self._get_whole_name(node) + if whole_name.split('.')[-1] in APIAnalysisSpec.get_network_base_class_names(): + return True + return False + + def _pre_process(self): + """Preprocessor checks the code before analyzing.""" + is_torch = False + + # check whether the code imports torch. + for ref_name in self._root_scope.external_references.keys(): + if ref_name.split('.')[0] in APIAnalysisSpec.get_convertible_external_names(): + is_torch = True + break + if not is_torch: + msg = "The source code does not import torch, model definition can not be found." + raise ScriptNotSupport(msg) + + # Find out external reference in the code and save it. + external_refs = self._analyze_import_references(self._root_scope) + self._is_standard_external_ref = self._check_external_standard(external_refs) + self._check_external_standard(external_refs) + for external_name, external_ref_info in external_refs.items(): + self._external_references.update({ + external_name: { + 'external_ref_info': external_ref_info, + 'parent_node': None + } + }) + + @staticmethod + def _analyze_import_references(root_scope): + """Find out all references from the import statements.""" + external_name_ref = {} + for node_references in root_scope.external_references.values(): + for node_ref in node_references: + if node_ref.name_ref: + # (from)import alias, node_ref.name_ref.id is alias name + if node_ref.name_ref.definition.asname == node_ref.name_ref.id: + external_name_ref[node_ref.name_ref.id] = node_ref + # import without alias, node_ref.name_ref.definition.asname is None. + # e.g., import a.b.c, reference maybe is a, a.b or a.b.c in the root_scope.external_references. + # The reference a.b.c is really wanted. + elif node_ref.name_ref.definition.name == node_ref.name_ref.id: + external_name_ref[node_ref.name_ref.id] = node_ref + else: + pass + + return external_name_ref + + def visit(self, node): + """Overridden visit of the base class to maintain stack information to access parent node.""" + self._stack.append(node) + super(CodeAnalyzer, self).visit(node) + self._stack.pop() + + @staticmethod + def _get_full_name(node): + """Get the full name of the node.""" + if not isinstance(node, (ast.Attribute, ast.Name)): + return None + return pasta.dump(node) + + def _get_whole_name(self, node): + """ + Get the whole name of the node. + + For example, nn.Module is spliced two nodes, nn node and Module node. + When visit ast nodes, + Module node is first visited, the full name is the same as the whole name, that is nn.Module. + And then nn node is visited, the full name is nn, the whole name is nn.Module. + """ + full_name = self._get_full_name(node) + if not full_name: + return None + + # node is in stack top pos + if node is self._stack[-1]: + parent_index = -1 + while isinstance(self._stack[parent_index], ast.Attribute): + parent_index -= 1 + + whole_name = self._get_full_name(self._stack[parent_index]) + else: + whole_name = full_name + return whole_name + + def _is_ref_convertible_imports(self, node): + """Check whether the node references convertible imports.""" + check_result = False + whole_name = self._get_whole_name(node) + if whole_name: + module_name = whole_name.split('.')[0] + for ref_name, ref_info in self._external_references.items(): + external_ref = ref_info['external_ref_info'] + # external reference is convertible module + if external_ref.name in APIAnalysisSpec.get_convertible_external_names(): + # import from the same external module + if module_name == ref_name.split('.')[0]: + check_result = True + break + + return check_result + + @staticmethod + def _get_external_node(external_references): + """Get all external reference nodes.""" + external_nodes = {} + for ref_name, ref_info in external_references.items(): + external_nodes.update({ref_info['external_ref_info'].node: ref_name}) + return external_nodes + + @staticmethod + def _get_convertible_external_node(external_name_ref): + """Get all convertible external reference nodes.""" + convertible_external_nodes = {} + for ref_name, ref_info in external_name_ref.items(): + if ref_info['external_ref_info'].name in APIAnalysisSpec.get_convertible_external_names(): + convertible_external_nodes.update({ref_info['external_ref_info'].node: ref_name}) + return convertible_external_nodes + + def _update_external_ref_parent(self, node): + """Set external reference parent node info.""" + external_nodes = self._get_external_node(self._external_references) + convertible_external_nodes = self._get_convertible_external_node(self._external_references) + for name_node in node.names: + if name_node in convertible_external_nodes.keys(): + if len(node.names) > 1: + msg = """\ + Not support multiple imports of torch on one line in your script. line:%s: %s + """ % (node.lineno, pasta.dump(node)) + raise ScriptNotSupport(msg) + if name_node in external_nodes.keys(): + ref_name = external_nodes[name_node] + self._external_references[ref_name]['parent_node'] = node + + @staticmethod + def _get_class_scope(node_scope): + """Find the class scope of the node_scope.""" + parent_scope = node_scope.parent_scope + class_scope = None + while parent_scope: + if isinstance(parent_scope.node, ast.ClassDef): + class_scope = parent_scope + break + parent_scope = parent_scope.parent_scope + return class_scope + + def _update_convertible_functions(self, node): + """Update convertible functions.""" + node_scope = self._root_scope.lookup_scope(node) + class_scope = self._get_class_scope(node_scope) + if class_scope: + network_classes = self._network_classes.get(class_scope, []) + if node_scope not in network_classes: + network_classes.append(node_scope) + else: + if node_scope not in self._network_functions: + self._network_functions.append(node_scope) + + def visit_ClassDef(self, node): + """Callback function when visit AST tree""" + if not self._stack[-1] is node: + return + + for base in node.bases: + if self._is_ref_convertible_imports(base): + self._network_classes[self._root_scope.lookup_scope(node)] = [] + + self.generic_visit(node) + + def visit_Import(self, node): + """Callback function when visit AST tree""" + self._update_external_ref_parent(node) + self.generic_visit(node) + + def visit_ImportFrom(self, node): + """Callback function when visit AST tree""" + self._update_external_ref_parent(node) + self.generic_visit(node) + + def visit_Call(self, node): + """Callback function when visit AST tree""" + if not self._stack[-1] is node: + return + is_in_network_function = False + # If torch call is happened in the function, save the function for network definition. + if self._functions_stack and self._is_ref_convertible_imports(node.func): + self._update_convertible_functions(self._functions_stack[-1]) + is_in_network_function = True + if not is_in_network_function: + self.generic_visit(node) + + def visit_FunctionDef(self, node): + """Callback function when visit AST tree""" + if not self._stack[-1] is node: + return + if node.name == "forward": + self._update_convertible_functions(node) + + self._functions_stack.append(node) + self.generic_visit(node) + self._functions_stack.pop() + + def get_name(self, node): + """ + Get the node name. + + Args: + node (AST): The ast node of the source code. + + Returns: + str, the name of the node + """ + if isinstance(node, pasta.base.scope.Scope): + items = [self.get_name(node.node)] + parent_scope = node.parent_scope + while parent_scope: + if not isinstance(parent_scope.node, ast.Module): + items.append(self.get_name(parent_scope.node)) + parent_scope = parent_scope.parent_scope + return '.'.join(reversed(items)) + if isinstance(node, (ast.ClassDef, ast.FunctionDef)): + return node.name + if isinstance(node, (ast.Name, ast.Attribute)): + return self._get_full_name(node) + return str(node) + + def lookup_scope(self, node): + """ + Search the scope of the node. + + Args: + node (AST): The ast node of the source code. + + Returns: + scope, the scope of the node + """ + if isinstance(node, pasta.base.scope.Scope): + return node + return self._root_scope.lookup_scope(node) diff --git a/mindinsight/mindconverter/common/exceptions.py b/mindinsight/mindconverter/common/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..494b2dcef8474ed7f477496cad569c645318aa3f --- /dev/null +++ b/mindinsight/mindconverter/common/exceptions.py @@ -0,0 +1,44 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Define custom exception.""" +from enum import unique + +from mindinsight.utils.constant import ScriptConverterErrors +from mindinsight.utils.exceptions import MindInsightException + + +@unique +class ConverterErrors(ScriptConverterErrors): + """Converter error codes.""" + SCRIPT_NOT_SUPPORT = 1 + NODE_TYPE_NOT_SUPPORT = 2 + + +class ScriptNotSupport(MindInsightException): + """The script can not support to process.""" + + def __init__(self, msg): + super(ScriptNotSupport, self).__init__(ConverterErrors.SCRIPT_NOT_SUPPORT, + msg, + http_code=400) + + +class NodeTypeNotSupport(MindInsightException): + """The astNode can not support to process.""" + + def __init__(self, msg): + super(NodeTypeNotSupport, self).__init__(ConverterErrors.NODE_TYPE_NOT_SUPPORT, + msg, + http_code=400) diff --git a/mindinsight/mindconverter/converter.py b/mindinsight/mindconverter/converter.py index 575f6a2ed7e0ac53bcad294c5a07a3e587fe9ccd..43da14bb25f9f0bb0c07cfc2fd6adac708b967be 100644 --- a/mindinsight/mindconverter/converter.py +++ b/mindinsight/mindconverter/converter.py @@ -13,463 +13,88 @@ # limitations under the License. # ============================================================================ """converter module""" -import copy -import importlib -import inspect import os import stat -from mindinsight.mindconverter.config import ALL_MAPPING -from mindinsight.mindconverter.config import NN_LIST -from mindinsight.mindconverter.config import ALL_TORCH_APIS -from mindinsight.mindconverter.config import ALL_2P_LIST -from mindinsight.mindconverter.config import UNSUPPORTED_WARN_INFOS -from mindinsight.mindconverter.config import ALL_UNSUPPORTED -from mindinsight.mindconverter.common.log import logger -from mindinsight.mindconverter.forward_call import ForwardCall +import pasta -LINE_NO_INDEX_DIFF = 1 +from mindinsight.mindconverter.common.exceptions import ScriptNotSupport +from mindinsight.mindconverter.common.log import logger +from mindinsight.mindconverter.ast_edits import AstEditVisitor class Converter: """Convert class""" - convert_info = '' flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL modes = stat.S_IWUSR | stat.S_IRUSR - @staticmethod - def is_local_defined(obj, member): - """ - Check if obj and member are both defined in the same source file. - - Args: - obj (Union[object, module]): A module or a class. - member (func): A function of obj. - - Returns: - bool, True or False. - """ - srcfile = inspect.getsourcefile(obj) - return inspect.getsourcefile(member) == srcfile + def __init__(self): + self._tree = None + self._infile = None + self._code_analyzer = None + self._ast_editor = None + self._report = [] - @classmethod - def is_valid_module(cls, obj, member): - """ - Check if obj and member defined in same source file and member is inherited from torch.nn.Module. - - Args: - obj (Union[object, module]): A module or a class. - member (func): A function. - - Returns: - bool, True or False. - """ - if inspect.isclass(member): - is_subclass = member.__base__.__name__ in ['Module', - 'Sequential', - 'ModuleList', - 'ModuleDict', - 'ParameterList', - 'ParameterDict'] - return is_subclass and cls.is_local_defined(obj, member) - return False - - @classmethod - def is_valid_function(cls, obj, member): - """ - Check if member is function and defined in the file same as obj. - - Args: - obj (Union[object, module]: The obj. - member (func): The func. - - Returns: - bool, True or False. - """ - return inspect.isfunction(member) and cls.is_local_defined(obj, member) - - @staticmethod - def find_left_parentheses(string, right): - """ - Find index of the first left parenthesis. - - Args: - string (str): A line of code. - right (int): The right index for string to find from. - - Returns: - int, index of the first parenthesis. - - Raises: - ValueError: If line of code doesn't contain any pair of `()` or `(` and `)` are not paired. - """ - if string[right] != ')': - raise ValueError('code [{}] at index {} not ")".'.format(string, right)) - stack = [] - for i in range(right, -1, -1): - if string[i] == ')': - stack.append(')') - elif string[i] == '(': - stack.pop() - if not stack: - return i - raise ValueError("{} should contain ()".format(string)) - - @staticmethod - def find_right_parentheses(string, left): + def convert(self, infile, output_dir, report_dir): """ - Find first index of right parenthesis which make all left parenthesis make sense. + Convert a module's code, code converted will be save in output_dir, and a report will be save in report_dir. Args: - string (str): A line of code. - left (int): Start index of string to find from. - - Returns: - int, index of the found right parenthesis. - - Raises: - ValueError: If line of code doesn't contain any pair of `()` or `(` and `)` are not paired. + infile (str): The script to convert. + output_dir (str): The path to save converted file. + report_dir (str): The path to save report file. """ - stack = [] - for i in range(left, len(string)): - if string[i] == '(': - stack.append('(') - elif string[i] == ')': - stack.pop() - if not stack: - return i - raise ValueError("{} should contain ()".format(string)) + in_file_split = _path_split(infile) + in_file_split[-1], _ = _get_name_ext(in_file_split[-1]) + module_name = '.'.join(in_file_split) + with open(infile, 'r') as file: + content = ''.join(file.readlines()) + + self._infile = infile + self._tree = pasta.parse(content) + self._report.clear() + try: + logger.info("Script file is %s", infile) + logger.info("Start converting %s", module_name) + self._report.append('[Start Convert]') + self._ast_editor = AstEditVisitor() + self._ast_editor.process(self._tree) + self._report.extend(self._ast_editor.get_logs()) + self._report.append('[Convert Over]') + dest_file = os.path.join(output_dir, os.path.basename(infile)) + with os.fdopen(os.open(dest_file, self.flags, self.modes), 'w') as file: + file.write(pasta.dump(self._tree)) + logger.info("Convert success. Result is wrote to %s.", dest_file) + except ScriptNotSupport as error: + self._report.append('[ScriptNotSupport] ' + error.message) + self._report.append('[Convert failed]') + raise error + except Exception as error: + self._report.clear() + raise error + finally: + if self._report: + dest_report_file = os.path.join(report_dir, + '_'.join(os.path.basename(infile).split('.')[:-1]) + '_report.txt') + with os.fdopen(os.open(dest_report_file, self.flags, self.modes), 'a') as file: + file.write('\n'.join(self._report)) + logger.info("Convert report is saved in %s", dest_report_file) @staticmethod - def get_call_name(code, end): + def convert_api(source_code): """ - Traverse code in a reversed function from index end and get the call name and start index of the call name, - if call name not found, return a null character string and -1 + Convert api_name in code to MindSpore api, if api_name is a python api, code will not convert. Args: - code (str): The str of code to find from. - end (int): Start index to find. - - Returns: - tuple(str, int), one is founded api name if found, else a null character string, the other is start index - of founded api name, -1 if api name not found - """ - stack = [] - for i in range(end - 1, -1, -1): - if code[i] in ["(", "[", "{"]: - if stack: - stack.pop() - else: - return code[i + 1:end], i + 1 - elif code[i] in [")", "]", "}"]: - stack.append(code[i]) - elif stack: - continue - elif not (code[i].isalpha() or code[i].isdigit() or code[i] == '_' or code[i] == '.'): - return code[i + 1:end], i + 1 - return "", -1 - - def convert_api(self, code, start, api_name=""): - """ - Convert api_name in code to MindSpore api with start as a start index, if api_name is a python api, - code will not convert. - - Args: - code (str): The str code to convert. - start (int): The index of code to start convert from. - api_name (str): The api name to convert. - + source_code (ast.Call): The ast node to convert. Returns: str, the converted code. - int, index of converted api_name in code. - """ - # handle format like .shape( - if api_name.startswith('.'): - call_name, new_start = self.get_call_name(code, start) - if start == -1 or call_name == "self": - return code, start + 1 - else: - call_name = api_name - new_start = start - - # find full api expected to be converted. eg:expr="nn.Conv2d(1,2,3)" args_str="(1,2,3)" - left = code.find("(", start) - if left == -1: - raise ValueError('"(" not found, {} should work with "("'.format(call_name)) - right = self.find_right_parentheses(code, left) - end = right - - expr = code[start:end + 1] - args_str = code[left:right + 1] - - map_helper = ALL_MAPPING[api_name] - new_expr = map_helper.convert(call_name, args_str) - next_newline = code.find("\n", end + 1) - fill_num = (expr.count("\n") - new_expr.count("\n")) - if next_newline != -1: - code = code[:new_start] + new_expr + code[end + 1:next_newline] + ("\n" * fill_num) + code[next_newline:] - else: - code = code[:new_start] + new_expr + ")" + ("\n" * fill_num) + code[end + 2:] - - return code, start + len(map_helper.ms_api.name) - - @staticmethod - def find_api(code, i, is_forward): - """ - Find api name from code with a start index i, check api name ok with a is_forward condition. - - Args: - code (str): The code from which to find api name. - i (int): The start index to find. - is_forward (bool): Check if the found api name ok. - - Returns: - str, api name if find api name and check ok with is_forward condition, else a null character string. - """ - if code[i:].startswith("nn.") \ - or code[i:].startswith("F.") \ - or code[i:].startswith("torch.") \ - or code[i:].startswith('.'): - j = code.find('(', i) - if j != -1 and code[i:j] in ALL_TORCH_APIS: - api_name = code[i:j] - if (not is_forward and api_name in NN_LIST) or (is_forward and api_name in ALL_2P_LIST): - return api_name - return "" - - def convert_function(self, fun_name, fun, is_forward): - """ - Convert a PyTorch function into MindSpore function. - - Args: - fun_name (str): The str of function name. - fun (func): The function to convert. - is_forward (bool): If the function is defined in forward function in nn.Module in torch. - - Returns: - dict, old code and converted code map if convert happens, else {}. """ - _, line_no = inspect.getsourcelines(fun) - logger.info("Line %3d: start converting function %s()", line_no, fun_name) - - code = inspect.getsource(fun) - code_saved = copy.copy(code) - - i = 0 - while i < len(code): - api_name = self.find_api(code, i, is_forward) - if api_name: - line_no1 = line_no + code[:i].count('\n') - if api_name in ALL_MAPPING: - logger.info("Line %3d start converting API: %s", line_no1, api_name) - code, i = self.convert_api(code, i, api_name) - self.convert_info += "[Convert][Line{:3d}] {} is converted.\n".format(line_no1, api_name) - continue - if api_name in ALL_UNSUPPORTED: - warn_info = ". " + UNSUPPORTED_WARN_INFOS[api_name] if api_name in UNSUPPORTED_WARN_INFOS else "" - logger.warning("Line %3d: found unsupported API: %s%s", line_no1, api_name, warn_info) - self.convert_info += "[Unconvert][Line{:3d}] {} didn't convert{}\n".format(line_no1, - api_name, warn_info) - i += 1 - return {code_saved: code} if code_saved != code else {} - - @staticmethod - def judge_forward(name, forward_list): - """ - Check if function is a forward function. - - Args: - name (str): The function name. - forward_list (set): A set of forward function. - - Returns: - bool, True or False - """ - is_forward = name in forward_list or name.split(".")[-1] == "forward" - if is_forward: - logger.debug("%s is a forward function", name) - return is_forward - - def convert_module(self, module_name, module, forward_list): - """ - Convert a PyTorch module code into MindSpore module code. - - Args: - module_name (str): The module's name. - module (module): The module to convert. - forward_list (set): A set of forward function. - - Returns: - dict, map of old code and converted code. - """ - _, line_no = inspect.getsourcelines(module) - logger.info("Line {:3d}: start converting nn.Module {}".format(line_no, module_name)) - - mapped = {} - for name, member in inspect.getmembers(module): - if self.is_valid_function(module, member): - is_forward = self.judge_forward("{}.{}".format(module_name, name), forward_list) - mapped.update(self.convert_function(name, member, is_forward)) - return mapped - - def get_mapping(self, import_mod, forward_list): - """ - Convert code of a module and get mapping of old code and convert code. - - Args: - import_mod (module): The module to convert. - forward_list (set): A set of forward function. - - Returns: - dict, mapping for old code and converted code of the module - """ - mapping = {} - tasks = [] - for name, member in inspect.getmembers(import_mod): - if self.is_valid_module(import_mod, member): - _, line_no = inspect.getsourcelines(member) - tasks.append((line_no, self.convert_module, (name, member, forward_list))) - elif self.is_valid_function(import_mod, member): - _, line_no = inspect.getsourcelines(member) - is_forward = self.judge_forward("{}.{}".format(import_mod, name), forward_list) - tasks.append((line_no, self.convert_function, (name, member, is_forward))) - tasks.sort() - for _, convert_fun, args in tasks: - mapping.update(convert_fun(*args)) - return mapping - - @staticmethod - def get_code_start_line_num(source_lines): - """ - Get the start code line number exclude comments. - - Args: - source_lines (list[str]): Split results of original code. - - Returns: - int, the start line number. - """ - stack = [] - index = 0 - for i, line in enumerate(source_lines): - if line.strip().startswith('#'): - continue - if line.strip().startswith('"""'): - if not line.endswith('"""\n'): - stack.append('"""') - continue - if line.strip().startswith("'''"): - if not line.endswith("'''\n"): - stack.append("'''") - continue - if line.endswith('"""\n') or line.endswith("'''\n"): - stack.pop() - continue - if line.strip() != '' and not stack: - index = i - break - return index - - def update_code_and_convert_info(self, code, mapping): - """ - Replace code according to mapping, and update convert info. - - Args: - code (str): The code to replace. - mapping (dict): Mapping for original code and the replaced code. - - Returns: - str, the replaced code. - """ - - for key, value in mapping.items(): - code = code.replace(key, value) - - source_lines = code.splitlines(keepends=True) - start_line_number = self.get_code_start_line_num(source_lines) - add_import_infos = ['import mindspore\n', - 'import mindspore.nn as nn\n', - 'import mindspore.ops.operations as P\n'] - for i, add_import_info in enumerate(add_import_infos): - source_lines.insert(start_line_number + i, add_import_info) - self.convert_info += '[Add Import] {}.\n'.format(add_import_info.strip()) - - insert_count = len(add_import_infos) - line_diff = insert_count - LINE_NO_INDEX_DIFF - - for i in range(start_line_number + insert_count, len(source_lines)): - line = source_lines[i] - - if (line.startswith('from torch') and 'import' in line) or line.startswith('import torch'): - new_line = '# ' + line - source_lines[i] = new_line - self.convert_info += '[Annotate][Line{:3d}] {} is annotated.\n'.format(i - line_diff, line.strip()) - if line.strip().startswith('class') and '(nn.Module)' in line: - new_line = line.replace('nn.Module', 'nn.Cell') - source_lines[i] = new_line - self.convert_info += '[Convert][Line{:3d}] nn.Module is converted.\n'.format(i - line_diff) - if line.strip().startswith('def forward('): - new_line = line.replace('forward', 'construct') - source_lines[i] = new_line - self.convert_info += '[Convert][Line{:3d}] forward is converted.\n'.format(i - line_diff) - if 'nn.Linear' in line: - new_line = line.replace('nn.Linear', 'nn.Dense') - source_lines[i] = new_line - self.convert_info += '[Convert][Line{:3d}] nn.Linear is converted.\n'.format(i - line_diff) - if '(nn.Sequential)' in line: - new_line = line.replace('nn.Sequential', 'nn.SequentialCell') - source_lines[i] = new_line - self.convert_info += '[Convert][Line{:3d}] nn.Sequential is converted.\n'.format(i - line_diff) - if 'nn.init.' in line: - new_line = line.replace('nn.init', 'pass # nn.init') - source_lines[i] = new_line - self.convert_info += '[Annotate][Line{:3d}] {} is annotated.\n'.format(i - line_diff, 'nn.init') - - code = ''.join(source_lines) - return code - - def convert(self, import_name, output_dir, report_dir): - """ - Convert a module's code, code converted will be save in output_dir, and a report will be save in report_dir. - - Args: - import_name (str): The module from which to import the module to convert. - output_dir (str): The path to save converted file. - report_dir (str): The path to save report file. - """ - logger.info("Start converting %s", import_name) - start_info = '[Start Convert]\n' - module_info = 'The module is {}.\n'.format(import_name) - - import_mod = importlib.import_module(import_name) - srcfile = inspect.getsourcefile(import_mod) - logger.info("Script file is %s", srcfile) - - forward_list = set(ForwardCall(srcfile).calls) - logger.debug("Forward_list: %s", forward_list) - - # replace python function under nn.Module - mapping = self.get_mapping(import_mod, forward_list) - code = inspect.getsource(import_mod) - code = self.update_code_and_convert_info(code, mapping) - convert_info_split = self.convert_info.splitlines(keepends=True) - convert_info_split = sorted(convert_info_split) - convert_info_split.insert(0, start_info) - convert_info_split.insert(1, module_info) - convert_info_split.append('[Convert Over]') - self.convert_info = ''.join(convert_info_split) - - dest_file = os.path.join(output_dir, os.path.basename(srcfile)) - with os.fdopen(os.open(dest_file, self.flags, self.modes), 'w') as file: - file.write(code) - logger.info("Convert success. Result is wrote to %s.", dest_file) - - dest_report_file = os.path.join(report_dir, - '_'.join(os.path.basename(srcfile).split('.')[:-1]) + '_report.txt') - with os.fdopen(os.open(dest_report_file, self.flags, self.modes), 'a') as file: - file.write(self.convert_info) - logger.info("Convert report is saved in %s", dest_report_file) + ast_node = pasta.parse(source_code).body[0].value + check_context = False + replaced_code = AstEditVisitor().mapping_api(ast_node, check_context) + return replaced_code def _get_name_ext(file): @@ -514,14 +139,6 @@ def main(files_config): files_config (dict): The config of files which to convert. """ convert_ins = Converter() - root_path = files_config['root_path'] in_files = files_config['in_files'] for in_file in in_files: - in_file_split = _path_split(in_file[len(root_path):]) - in_file_split[-1], _ = _get_name_ext(in_file_split[-1]) - module_name = '.'.join(in_file_split) - convert_ins.convert(module_name, files_config['outfile_dir'], files_config['report_dir']) - - in_module = files_config.get('in_module') - if in_module: - convert_ins.convert(in_module, files_config['outfile_dir'], files_config['report_dir']) + convert_ins.convert(in_file, files_config['outfile_dir'], files_config['report_dir']) diff --git a/mindinsight/mindconverter/forward_call.py b/mindinsight/mindconverter/forward_call.py index 07e52707300a275ef5523a2a513f0bf6bf3105aa..486303125276fa6affeff65a9f20f71da3117d45 100644 --- a/mindinsight/mindconverter/forward_call.py +++ b/mindinsight/mindconverter/forward_call.py @@ -14,7 +14,8 @@ # ============================================================================ """Find out forward functions of script file""" import ast -import os + +import pasta class ForwardCall(ast.NodeVisitor): @@ -24,73 +25,80 @@ class ForwardCall(ast.NodeVisitor): Find the sub functions called by the forward function in the script file. """ - def __init__(self, filename): - self.filename = filename - self.module_name = os.path.basename(filename).replace('.py', '') - self.name_stack = [] - self.forward_stack = [] - self.calls = set() + def __init__(self, ast_tree): + self._tree = ast_tree + self._name_stack = [] + self._forward_stack = [] + self.calls = {} # key is function name, value is forward function ast node. + self._function_list = {} # key is function name, value is function ast node. self.process() def process(self): - """Parse the python source file to find the forward functions.""" - with open(self.filename, 'rt', encoding='utf-8') as file: - content = file.read() - self.visit(ast.parse(content, self.filename)) + """visit ast tree to find the forward functions.""" + self.visit(self._tree) + # first visit to find out all functions, so restores all variables except _function_list + self._name_stack.clear() + self._forward_stack.clear() + self.calls.clear() + self.visit(self._tree) def get_current_namespace(self): """Get the namespace when visit the AST node""" - namespace = '.'.join(self.name_stack) + namespace = '.'.join(self._name_stack) return namespace @classmethod - def get_ast_node_name(cls, node): - """Get AST node name.""" - if isinstance(node, ast.Attribute): - return f'{cls.get_ast_node_name(node.value)}.{node.attr}' - - if isinstance(node, ast.Name): - return node.id + def get_call_name(cls, node): + """Get functional call name.""" + if not isinstance(node, ast.Call): + return None - return node + return pasta.dump(node.func) def visit_ClassDef(self, node): """Callback function when visit AST tree""" - self.name_stack.append(node.name) + self._name_stack.append(node.name) self.generic_visit(node) - self.name_stack.pop() + self._name_stack.pop() def visit_FunctionDef(self, node): """Callback function when visit AST tree""" + namespace = self.get_current_namespace() + if namespace: + func_name = f'{namespace}.{node.name}' + else: + func_name = node.name func_name = f'{self.get_current_namespace()}.{node.name}' is_in_chain = func_name in self.calls or node.name == 'forward' if is_in_chain: - self.forward_stack.append(func_name) + self._forward_stack.append(func_name) if node.name == 'forward': - self.calls.add(func_name) + self.calls.update({func_name: node}) + self._function_list.update({func_name: node}) self.generic_visit(node) if is_in_chain: - self.forward_stack.pop() + self._forward_stack.pop() def visit_Call(self, node): """Callback function when visit AST tree""" for arg in node.args: self.visit(arg) - for kw in node.keywords: - self.visit(kw.value) - func_name = self.get_ast_node_name(node.func) + for keyword in node.keywords: + self.visit(keyword.value) + func_name = self.get_call_name(node) if isinstance(node.func, ast.Name): if func_name not in ['super', 'str', 'repr']: - if self.forward_stack: - self.calls.add(func_name) + if self._forward_stack: + self.calls.update({func_name: self._function_list.get(func_name)}) self.visit(node.func) else: - if self.forward_stack: - if 'self' in func_name: - self.calls.add(f'{self.get_current_namespace()}.{func_name.split(".")[-1]}') + if self._forward_stack: + if func_name.startswith('self.'): + whole_name = f'{self.get_current_namespace()}.{func_name.split(".")[-1]}' + self.calls.update({whole_name: self._function_list.get(whole_name)}) else: - self.calls.add(func_name) + self.calls.update({func_name: self._function_list.get(func_name)}) self.visit(node.func) diff --git a/mindinsight/utils/constant.py b/mindinsight/utils/constant.py index 348da1cb902170223ebc69de23e088cc9a1cf30e..87f39d680088ed60c5e6fcfbdfca2592d5ae7619 100644 --- a/mindinsight/utils/constant.py +++ b/mindinsight/utils/constant.py @@ -30,6 +30,7 @@ class MindInsightModules(Enum): LINEAGEMGR = 2 DATAVISUAL = 5 PROFILERMGR = 6 + SCRIPTCONVERTER = 7 class GeneralErrors(Enum): @@ -69,3 +70,7 @@ class DataVisualErrors(Enum): SCALAR_NOT_EXIST = 14 HISTOGRAM_NOT_EXIST = 15 TRAIN_JOB_DETAIL_NOT_IN_CACHE = 16 + + +class ScriptConverterErrors(Enum): + """Enum definition for mindconverter errors.""" diff --git a/tests/ut/mindconverter/test_converter.py b/tests/ut/mindconverter/test_converter.py index 44ad641b7a182f3fac6bad5b4e70325145eca00c..c55ba1f2e817044aaed7385ae0cdb66b4db344c8 100644 --- a/tests/ut/mindconverter/test_converter.py +++ b/tests/ut/mindconverter/test_converter.py @@ -22,380 +22,201 @@ class TestConverter: converter_ins = Converter() - def test_judge_forward(self): - """test judge_forward""" - name1 = 'conv1' - forward_list = {'conv1', 'relu'} - result1 = self.converter_ins.judge_forward(name1, forward_list) - assert result1 is True - - name2 = 'self.forward' - result2 = self.converter_ins.judge_forward(name2, forward_list) - assert result2 is True - - def test_find_left_parentheses(self): - """test find_left_parentheses""" - code = '''nn.Sequential(nn.Conv2d(in_dim, 6, 5, stride=1, padding=0, ), - nn.ReLU(), - nn.ReLU(True), - nn.MaxPool2d(2, 2), - nn.Conv2d(6, 16, 5, stride=1, padding=0), - nn.ReLU(inplace=False), - nn.MaxPool2d(2, 2))''' - right_index = len(code) - 1 - left_index = code.index('nn.Conv2d') - result = self.converter_ins.find_left_parentheses(code, right_index) - assert result == left_index - 1 - - def test_find_api(self): - """test find_api""" - code = '''nn.Sequential(nn.Conv2d(in_dim, 6, 5, stride=1, padding=0, ), - nn.ReLU(), - nn.ReLU(True), - nn.MaxPool2d(2, 2), # TODO padding - nn.Conv2d(6, 16, 5, stride=1, padding=0), - nn.ReLU(inplace=False), - nn.MaxPool2d(2, 2))''' - index = 0 - is_forward = False - result = self.converter_ins.find_api(code, index, is_forward) - assert result == 'nn.Sequential' - - def test_get_call_name(self): - """test get_call_name""" - code = '''nn.Sequential(nn.Conv2d(in_dim, 6, 5, stride=1, padding=0))''' - end = len(code) - call_name, index = self.converter_ins.get_call_name(code, end) - - assert call_name == '' - assert index == -1 - - def test_find_right_parentheses(self): - """test find_right_parentheses""" - code = '''nn.Sequential(nn.Conv2d(in_dim, 6, 5, stride=1, padding=0, ), - nn.ReLU(), - nn.ReLU(True), - nn.MaxPool2d(2, 2), # TODO padding - nn.Conv2d(6, 16, 5, stride=1, padding=0), - nn.ReLU(inplace=False), - nn.MaxPool2d(2, 2))''' - left_index = 0 - result = self.converter_ins.find_right_parentheses(code, left_index) - assert_index = len(code) - 1 - assert result == assert_index - # test convert_api with nn ops def test_convert_api_nn_layernorm(self): """Test convert_api function work ok when convert api nn.LayerNorm""" - code = """ - def __init__(self, num_classes=1000): - self.features = nn.SequentialCell([ - nn.LayerNorm((5, 10, 10), elementwise_affine=False), - nn.ReLU(inplace=False) - ]) - """ + code = "nn.LayerNorm((5, 10, 10), elementwise_affine=False)" api_name = 'nn.LayerNorm' - start = code.find(api_name) layer_norm_info = NN_MAPPING.get(api_name) expected_ms_api_name = 'nn.LayerNorm' epsilon = layer_norm_info.pt_api.params.get('eps') - replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) + replaced_code = self.converter_ins.convert_api(code) assert replaced_code == code.replace('nn.LayerNorm((5, 10, 10), elementwise_affine=False)', '{}(normalized_shape=(5, 10, 10), epsilon={})'.format( expected_ms_api_name, epsilon)) - assert new_start == start + len(expected_ms_api_name) def test_convert_api_nn_leaky_relu(self): """Test convert_api function work ok when convert api nn.LeakyReLU""" - code = """ - def __init__(self, num_classes=1000): - self.features = nn.SequentialCell([ - nn.LayerNorm((5, 10, 10), elementwise_affine=False), - nn.LeakyReLU(0.3)]) - """ - api_name = 'nn.LeakyReLU' - start = code.find(api_name) + code = "nn.LeakyReLU(0.3)" expected_ms_api_name = 'nn.LeakyReLU' - replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) + replaced_code = self.converter_ins.convert_api(code) assert replaced_code == code.replace('nn.LeakyReLU(0.3)', '{}(alpha=0.3)'.format(expected_ms_api_name)) - assert new_start == start + len(expected_ms_api_name) def test_convert_api_nn_prelu(self): """Test convert_api function work ok when convert api nn.PReLU""" - code = """ - input = torch.randn(2, 3, 5) - nn.PReLU()(input) - - """ - api_name = 'nn.PReLU' - start = code.find(api_name) + code = "nn.PReLU()(input)" expected_ms_api_name = 'nn.PReLU' - replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) + replaced_code = self.converter_ins.convert_api(code) assert replaced_code == code.replace('nn.PReLU()(input)', '{}()(input)'.format(expected_ms_api_name)) - assert new_start == start + len(expected_ms_api_name) def test_convert_api_nn_softmax(self): """Test convert_api function work ok when convert api nn.Softmax""" - code = """ - nn.Softmax(dim=1)(input) - """ - api_name = 'nn.Softmax' + code = "nn.Softmax(dim=1)" expected_ms_api_name = 'nn.Softmax' - start = code.find(api_name) - replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) - assert replaced_code == code.replace('nn.Softmax(dim=1)(input)', - '{}(axis=1)(input)'.format(expected_ms_api_name)) - assert new_start == start + len(expected_ms_api_name) + replaced_code = self.converter_ins.convert_api(code) + assert replaced_code == code.replace('nn.Softmax(dim=1)', + '{}(axis=1)'.format(expected_ms_api_name)) # test convert_api with torch dot ops def test_convert_api_torch_dot_abs(self): """Test convert_api function work ok when convert api torch.abs""" - code = """ - torch.abs(input) - """ - api_name = 'torch.abs' - start = code.find(api_name) + code = "torch.abs(input)" expected_ms_api_name = 'P.Abs' - replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) + replaced_code = self.converter_ins.convert_api(code) assert replaced_code == code.replace('torch.abs(input)', '{}()(input)'.format(expected_ms_api_name)) - assert new_start == start + len(expected_ms_api_name) def test_convert_api_torch_dot_acos(self): """Test convert_api function work ok when convert api torch.acos""" - code = """ - torch.acos(input) - """ - api_name = 'torch.acos' - start = code.find(api_name) + code = "torch.acos(input)" expected_ms_api_name = 'P.ACos' - replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) + replaced_code = self.converter_ins.convert_api(code) assert replaced_code == code.replace('torch.acos(input)', '{}()(input)'.format(expected_ms_api_name)) - assert new_start == start + len(expected_ms_api_name) def test_convert_api_torch_dot_cos(self): """Test convert_api function work ok when convert api torch.cos""" - code = """ - torch.cos(input) - """ - api_name = 'torch.cos' + code = "torch.cos(input)" expected_ms_api_name = 'P.Cos' - start = code.find(api_name) - replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) + replaced_code = self.converter_ins.convert_api(code) assert replaced_code == code.replace('torch.cos(input)', '{}()(input)'.format(expected_ms_api_name)) - assert new_start == start + len(expected_ms_api_name) def test_convert_api_torch_dot_exp(self): """Test convert_api function work ok when convert api torch.exp""" - code = """ - torch.exp(input) - """ - api_name = 'torch.exp' + code = "torch.exp(input)" expected_ms_api_name = 'P.Exp' - start = code.find(api_name) - replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) + replaced_code = self.converter_ins.convert_api(code) assert replaced_code == code.replace('torch.exp(input)', '{}()(input)'.format(expected_ms_api_name)) - assert new_start == start + len(expected_ms_api_name) def test_convert_api_torch_dot_log(self): """Test convert_api function work ok when convert api torch.log""" - code = """ - torch.log(input) - """ - api_name = 'torch.log' + code = "torch.log(input)" expected_ms_api_name = 'P.Log' - start = code.find(api_name) - replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) + replaced_code = self.converter_ins.convert_api(code) assert replaced_code == code.replace('torch.log(input)', '{}()(input)'.format(expected_ms_api_name)) - assert new_start == start + len(expected_ms_api_name) def test_convert_api_torch_dot_pow(self): """Test convert_api function work ok when convert api torch.pow""" - code = """ - torch.pow(a, exp) - """ - api_name = 'torch.pow' + code = "torch.pow(a, exp)" expected_ms_api_name = 'P.Pow' - start = code.find(api_name) - replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) + replaced_code = self.converter_ins.convert_api(code) assert replaced_code == code.replace('torch.pow(a, exp)', '{}()(a, exp)'.format(expected_ms_api_name)) - assert new_start == start + len(expected_ms_api_name) def test_convert_api_torch_dot_div(self): """Test convert_api function work ok when convert api torch.div""" - code = """ - input = torch.randn(5) - other = torch.randn(5) - torch.div(input, other) - """ - api_name = 'torch.div' + code = "torch.div(input, other)" expected_ms_api_name = 'P.Div' - start = code.find(api_name) - replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) + replaced_code = self.converter_ins.convert_api(code) assert replaced_code == code.replace('torch.div(input, other)', '{}()(input, other)'.format(expected_ms_api_name)) - assert new_start == start + len(expected_ms_api_name) def test_convert_api_torch_dot_sin(self): """Test convert_api function work ok when convert api torch.sin""" - code = """ - torch.sin(input) - """ - api_name = 'torch.sin' + code = "torch.sin(input)" expected_ms_api_name = 'P.Sin' - start = code.find(api_name) - replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) + replaced_code = self.converter_ins.convert_api(code) assert replaced_code == code.replace('torch.sin(input)', '{}()(input)'.format(expected_ms_api_name)) - assert new_start == start + len(expected_ms_api_name) def test_convert_api_torch_dot_sqrt(self): """Test convert_api function work ok when convert api torch.sqrt""" - code = """ - torch.sqrt(input) - """ - api_name = 'torch.sqrt' + code = "torch.sqrt(input)" expected_ms_api_name = 'P.Sqrt' - start = code.find(api_name) - replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) + replaced_code = self.converter_ins.convert_api(code) assert replaced_code == code.replace('torch.sqrt(input)', '{}()(input)'.format(expected_ms_api_name)) - assert new_start == start + len(expected_ms_api_name) def test_convert_api_torch_dot_eye_with_n(self): """Test convert_api function work ok when convert api torch.eye""" - code = """ - torch.eye(3) - """ - api_name = 'torch.eye' + code = "torch.eye(3)" expected_ms_api_name = 'P.Eye' - start = code.find(api_name) - replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) + replaced_code = self.converter_ins.convert_api(code) assert replaced_code == code.replace('torch.eye(3)', '{}()(3, 3, mindspore.int32)'.format(expected_ms_api_name)) - assert new_start == start + len(expected_ms_api_name) def test_convert_api_torch_dot_eye_with_m(self): """Test convert_api function work ok when convert api torch.eye""" - code = """ - torch.eye(3, 4) - """ - api_name = 'torch.eye' + code = "torch.eye(3, 4)" expected_ms_api_name = 'P.Eye' - start = code.find(api_name) - replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) + replaced_code = self.converter_ins.convert_api(code) assert replaced_code == code.replace('torch.eye(3, 4)', '{}()(3, 4, mindspore.int32)'.format(expected_ms_api_name)) - assert new_start == start + len(expected_ms_api_name) def test_convert_api_torch_dot_add_with_alpha_default(self): """Test convert_api function work ok when convert api torch.add""" - code = """ - torch.add(input, value) - """ - api_name = 'torch.add' + code = "torch.add(input, value)" expected_ms_api_name = 'P.TensorAdd' - start = code.find(api_name) - replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) + replaced_code = self.converter_ins.convert_api(code) assert replaced_code == code.replace('torch.add(input, value)', '{}()(input, value)'.format(expected_ms_api_name)) - assert new_start == start + len(expected_ms_api_name) def test_convert_api_torch_dot_add_with_alpha_not_default(self): """Test convert_api function work ok when convert api torch.add""" - code = """ - torch.add(input, value, 3) - """ - api_name = 'torch.add' + code = "torch.add(input, value, 3)" expected_ms_api_name = 'P.TensorAdd' - start = code.find(api_name) - replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) + replaced_code = self.converter_ins.convert_api(code) assert replaced_code == code.replace('torch.add(input, value, 3)', '{}()(input, value*3)'.format(expected_ms_api_name)) - assert new_start == start + len(expected_ms_api_name) # test convert_api with F ops def test_convert_api_f_normalize(self): """Test convert_api function work ok when convert api F.normalize""" - code = """ - input = torch.randn(2, 3, 5) - F.normalize(input) - """ - api_name = 'F.normalize' - start = code.find(api_name) + code = "F.normalize(input)" expected_ms_api_name = 'P.L2Normalize' - replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) + replaced_code = self.converter_ins.convert_api(code) assert replaced_code == code.replace('F.normalize(input)', '{}(1, 1e-12)(input)'.format(expected_ms_api_name)) - assert new_start == start + len(expected_ms_api_name) def test_convert_api_f_sigmoid(self): """Test convert_api function work ok when convert api F.sigmoid""" - code = """ - input = torch.randn(2, 3, 5) - F.sigmoid(input) - """ - api_name = 'F.sigmoid' - start = code.find(api_name) + code = "F.sigmoid(input)" expected_ms_api_name = 'P.Sigmoid' - replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) + replaced_code = self.converter_ins.convert_api(code) assert replaced_code == code.replace('F.sigmoid(input)', '{}()(input)'.format(expected_ms_api_name)) - assert new_start == start + len(expected_ms_api_name) # test convert_api with tensor dot ops def test_convert_api_tensor_dot_repeat(self): """Test convert_api function work ok when convert api .repeat""" - code = """ - x.repeat(4, 2) - """ - api_name = '.repeat' - start = code.find(api_name) + code = "x.repeat(4, 2)" expected_ms_api_name = 'P.Tile' - replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) + replaced_code = self.converter_ins.convert_api(code) assert replaced_code == code.replace('x.repeat(4, 2)', '{}()(x, {})'.format(expected_ms_api_name, '(4, 2,)')) - assert new_start == start + len(expected_ms_api_name) def test_convert_api_tensor_dot_permute(self): """Test convert_api function work ok when convert api .permute""" - code = """ - x.permute(2, 0, 1) - """ - api_name = '.permute' - start = code.find(api_name) + code = "x.permute(2, 0, 1)" expected_ms_api_name = 'P.Transpose' - replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) + replaced_code = self.converter_ins.convert_api(code) assert replaced_code == code.replace('x.permute(2, 0, 1)', '{}()(x, (2, 0, 1,))'.format(expected_ms_api_name)) - assert new_start == start + len(expected_ms_api_name) diff --git a/tests/ut/mindconverter/test_forward_call.py b/tests/ut/mindconverter/test_forward_call.py index 71928f12931cbfbc7d4d81fb898b76503420a5db..8d01d8a8ce9990769ab941e50475f3a25926d980 100644 --- a/tests/ut/mindconverter/test_forward_call.py +++ b/tests/ut/mindconverter/test_forward_call.py @@ -15,7 +15,6 @@ """Test forward_call module.""" import ast import textwrap -from unittest.mock import patch from mindinsight.mindconverter.forward_call import ForwardCall @@ -50,12 +49,10 @@ class TestForwardCall: return out """) - @patch.object(ForwardCall, 'process') - def test_process(self, mock_process): + def test_process(self): """Test the function of visit ast tree to find out forward functions.""" - mock_process.return_value = None - forward_call = ForwardCall("mock") - forward_call.visit(ast.parse(self.source)) + ast_tree = ast.parse(self.source) + forward_call = ForwardCall(ast_tree) expect_calls = ['TestNet.forward', 'TestNet.forward1', @@ -70,6 +67,6 @@ class TestForwardCall: 'TestNet.fc3', ] expect_calls.sort() - real_calls = list(forward_call.calls) + real_calls = list(forward_call.calls.keys()) real_calls.sort() assert real_calls == expect_calls