提交 bcdc61cc 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!273 Converter: use the AST to analyze and modify network definition script

Merge pull request !273 from ggpolar/br_wzk_dev
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless REQUIRED by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Convert for Python scripts according API mapping information."""
import ast
import logging
import re
from enum import Enum
import pasta
from pasta.augment import import_utils
from mindinsight.mindconverter.code_analysis import CodeAnalyzer
from mindinsight.mindconverter.code_analysis import APIAnalysisSpec
from mindinsight.mindconverter.config import ALL_MAPPING
from mindinsight.mindconverter.config import NN_LIST
from mindinsight.mindconverter.config import ALL_TORCH_APIS
from mindinsight.mindconverter.config import ALL_2P_LIST
from mindinsight.mindconverter.config import UNSUPPORTED_WARN_INFOS
from mindinsight.mindconverter.config import ALL_UNSUPPORTED
from mindinsight.mindconverter.common.log import logger
from mindinsight.mindconverter.common.exceptions import NodeTypeNotSupport
from mindinsight.mindconverter.forward_call import ForwardCall
LOG_FMT_CONVERT = "[Convert] '%s' is converted to '%s'."
LOG_FMT_NOT_CONVERT = "[UnConvert] '%s' didn't convert. %s"
LOG_FMT_PROMPT_INFO = "[INFO] %s"
LOG_SUGGESTION_MANUAL_CONVERT = "Please manual convert the code, along with the code associated with it."
class ApiMatchingEnum(Enum):
"""Node edge type enum."""
NOT_API = 'not an api name'
API_INFER = 'infer api name to map'
API_STANDARD = 'api name in the correct format'
API_FOUND = 'found an api name in api list'
API_MATCHED = 'api is matched to map'
class _ConvertReport:
"""Report log of converting source code."""
def __init__(self, is_stub=False):
self._is_stub = is_stub
self._max_line = 0
self._log = [] # report log, type is (severity, line, col, msg)
def _add_log(self, severity, line, col, msg):
"""Add log."""
if self._is_stub:
return
if isinstance(line, int) and isinstance(col, int):
self._log.append((severity, line, col, msg))
if self._max_line < line:
self._max_line = line
def info(self, line, col, msg):
"""Interface to add infer log"""
self._add_log(logging.INFO, line, col, msg)
def warning(self, line, col, msg):
"""Interface to add warning log"""
self._add_log(logging.WARNING, line, col, msg)
def get_logs(self):
"""Get convert logs"""
logs = []
# sort rule: line * self._max_line + col
self._log.sort(key=lambda log: log[1] * self._max_line + log[2])
for log_info in self._log:
log_info = "line %d:%d: %s" % (log_info[1], log_info[2], log_info[3])
logs.append(log_info)
return logs
class _LineColEditVisitor(ast.NodeVisitor):
"""
Update line number and col offset of ast node.
Use the line and column number of the original code to update
the line and column number of the new code replaced with the original code.
"""
class _NodeInfo:
"""NodeInfo class definition."""
def __init__(self, node):
self.node = node
self.call_list = [] # Used to save all ast.Call node in self._node
def __init__(self):
self._dst_node_info = None
self._src_node_info = None
self._visiting = self._src_node_info # Used to point to the visiting node
def update(self, replace_with_node, src_node):
"""Update the line and column number of the new code replaced with the original code."""
replace_with_node.lineno = src_node.lineno
replace_with_node.col_offset = src_node.col_offset
self._dst_node_info = self._NodeInfo(replace_with_node)
self._src_node_info = self._NodeInfo(src_node)
self._visiting = self._src_node_info
self.visit(self._visiting.node)
self._visiting = self._dst_node_info
self.visit(self._visiting.node)
self._update_line_col()
def visit_Call(self, node):
"""Callback function when visit AST tree"""
self._visiting.call_list.append(node)
self.generic_visit(node)
def _update_line_col(self):
"""Update the line and column number information for all ast.Call node."""
dst_call_list = list(self._dst_node_info.call_list)
src_call_list = list(self._src_node_info.call_list)
len_diff = len(dst_call_list) - len(src_call_list)
# After MindSpore api replaces Torch api, more calls are generated.
# For example, out.view() is replaced with P.Reshape()(out).
# out.view() has only one call, but P.Reshape()(out) has two calls.
# To match the replaced calls, the calls of out.view is padded to the same quantity.
if len_diff > 0:
src_call_list = [src_call_list[0]] * len_diff + src_call_list
for dst_call, src_call in zip(dst_call_list, src_call_list):
dst_call.lineno = src_call.lineno
dst_call.col_offset = src_call.col_offset
if not dst_call.args:
continue
# When out.size().view(1, ...) transforms to P.Reshape()(out.size(), 1, ...),
# in this case, the column of parameter out.size() will be bigger than the following parameters.
# To ensure the sequence of parameters, adjust the column of the second parameter.
args = []
for arg in dst_call.args:
if self._check_arg2update(arg):
args.append(arg)
for arg in args:
arg.lineno = dst_call.lineno
arg.col_offset += dst_call.col_offset
@staticmethod
def _check_arg2update(arg):
# Only the col_offset of the first line code is re-counted, needs to be corrected.
# When the arg is a function call, its col_offset is handled separately.
if not isinstance(arg, ast.Call) and arg.lineno == 1:
return True
return False
class AstEditVisitor(ast.NodeVisitor):
"""AST Visitor that process function calls.
Converts function calls from torch api to MindSpore api using api mapping information.
"""
def __init__(self):
self._process_log = _ConvertReport()
self._tree = None
self._code_analyzer = None
self._stack = [] # Used to easily access the parent node
self._forward_list = {}
self._is_forward_function = False # Used to allow access the visiting function forward attribute
self._new_call_nodes = [] # Used to save new ast.call nodes
def process(self, ast_tree):
"""
Convert source code to MindSpore code.
Args:
ast_tree (AST): The root node of the source code.
"""
self.__init__()
self._tree = ast_tree
self._code_analyzer = CodeAnalyzer()
self._code_analyzer.process(self._tree)
self._forward_list = ForwardCall(self._tree).calls
# replace python function under nn.Module
self._convert_api()
# replace external reference statements
self._convert_external_reference()
def get_logs(self):
"""Get conversion report."""
return self._process_log.get_logs()
def _convert_cell(self, cell_scope):
"""
Convert a PyTorch Module class into MindSpore Cell class.
Args:
cell_scope (pasta.base.Scope): The network class definition node inherits from torch.nn.Module.
"""
cell_ast_node = cell_scope.node
line_no = cell_ast_node.lineno
logger.info("Line %3d: start converting nn.Module %s", line_no, self._code_analyzer.get_name(cell_ast_node))
class_elements = self._code_analyzer.network_definitions()['cell']
# step1. update function definition
for func_scope in class_elements.get(cell_scope, []):
self._update_function_def(func_scope)
# step2. update base name of class
self._update_base_name(cell_scope)
def _update_base_name(self, class_def_scope):
"""
Update base name of class.
Args:
class_def_scope (ast.ClassDef): Class definition node.
"""
base_name_mapping = APIAnalysisSpec.base_name_mapping
class_def_node = class_def_scope.node
base_class_nodes = class_def_scope.node.bases
# update base class name
for base_class_node in base_class_nodes:
base_name = base_class_node.attr
if base_name in APIAnalysisSpec.get_network_base_class_names():
old_code = pasta.dump(base_class_node)
if base_name in base_name_mapping:
new_code = 'nn.' + base_name_mapping[base_class_node.attr]
new_node = pasta.parse(new_code)
pasta.ast_utils.replace_child(class_def_node, base_class_node, new_node)
self._process_log.info(base_class_node.lineno, base_class_node.col_offset, LOG_FMT_CONVERT %
(old_code, new_code))
else:
self._process_log.info(base_class_node.lineno, base_class_node.col_offset, LOG_FMT_NOT_CONVERT %
(old_code, ''))
def _update_function_def(self, func_scope):
"""
Convert a PyTorch function into MindSpore function.
Args:
func_scope (pasta.base.scope.Scope): The node scope of function definition.
"""
is_forward = self._judge_forward(func_scope)
# step1. convert the content of the function.
self._convert_function(func_scope, is_forward)
# step2. replace function name if name is forward
func_ast_node = func_scope.node
old_func_name = 'forward'
new_func_name = 'construct'
if func_ast_node.name == old_func_name:
func_ast_node.name = new_func_name
self._process_log.info(func_ast_node.lineno, func_ast_node.col_offset,
LOG_FMT_CONVERT % (old_func_name, new_func_name))
def _convert_api(self):
"""Convert PyTorch api call to MindSpore api call in a function."""
tasks = []
convert_elements = self._code_analyzer.network_definitions()
for func_node_scope in convert_elements.get("functions", []):
is_forward = self._judge_forward(func_node_scope)
tasks.append((self._convert_function, (func_node_scope, is_forward)))
for class_scope in convert_elements.get("cell", []).keys():
tasks.append((self._convert_cell, (class_scope,)))
for convert_fun, args in tasks:
convert_fun(*args)
def _convert_external_reference(self):
"""Convert import statements."""
name_replace = APIAnalysisSpec.import_name_mapping
replace_imports = list(name_replace.values())
for ref_info in self._code_analyzer.external_references.values():
external_ref_info = ref_info['external_ref_info']
parent_node = ref_info['parent_node']
if parent_node is None:
continue
code = pasta.dump(parent_node)
if external_ref_info.name in APIAnalysisSpec.get_convertible_external_names():
external_ref_info = ref_info['external_ref_info']
if external_ref_info.name in name_replace.keys():
import_utils.remove_import_alias_node(self._code_analyzer.root_scope, external_ref_info.node)
replace_info = name_replace[external_ref_info.name]
new_ref_name = replace_info[1]
new_external_name = replace_info[0]
if new_ref_name:
new_code = f'import {new_external_name} as {new_ref_name}'
else:
new_code = f'import {new_external_name}'
self._process_log.info(parent_node.lineno, parent_node.col_offset, LOG_FMT_CONVERT %
(code.strip(), new_code.strip()))
elif external_ref_info.name.startswith('torch.'):
self._process_log.warning(parent_node.lineno, parent_node.col_offset, LOG_FMT_NOT_CONVERT %
(code.strip(), LOG_SUGGESTION_MANUAL_CONVERT))
else:
pass
# Insert import in reverse order, display in forward order.
for idx in range(len(replace_imports) - 1, -1, -1):
replace_import = replace_imports[idx]
if replace_import[1]:
self._add_import(name_to_import=replace_import[0], as_name=replace_import[1])
else:
self._add_import(name_to_import=replace_import[0])
def _add_import(self, name_to_import, as_name=None):
"""
Adds an import to the ast tree.
Args:
name_to_import: (string) The absolute name to import.
as_name: (string) The alias for the import ("import name_to_import as asname")
"""
new_alias = ast.alias(name=name_to_import, asname=as_name)
import_node = ast.Import(names=[new_alias])
# Insert the node at the top of the module
self._tree.body.insert(1 if pasta.base.ast_utils.has_docstring(self._tree) else 0, import_node)
def _convert_function(self, func_scope, is_forward):
"""
Convert a PyTorch function into MindSpore function.
Args:
func_scope (pasta.base.scope.Scope): The node scope of function definition.
is_forward (boolean): If the function is defined in forward function in nn.Module in torch.
"""
func_ast_node = func_scope.node
line_no = func_ast_node.lineno
logger.info("Line %3d: start converting function %s()", line_no, func_ast_node.name)
parent = func_scope.parent_scope.node
self._stack.clear()
self._new_call_nodes.clear()
if parent:
self._stack.append(parent)
self._is_forward_function = is_forward
self.visit(func_scope.node)
def _judge_forward(self, func_scope):
"""
Check if function is a forward function.
Args:
func_scope (pasta.base.scope.Scope): The node scope of function definition.
Returns:
boolean, True or False
"""
is_forward = func_scope.node in self._forward_list.values()
if is_forward:
logger.debug("%s is a forward function", self._code_analyzer.get_name(func_scope))
return is_forward
# Overridden to maintain stack information to access parent node
def visit(self, node):
"""Visit a ast tree."""
self._stack.append(node)
super(AstEditVisitor, self).visit(node)
self._stack.pop()
def _mapping_standard_api_name(self, api_name):
"""Get mapping from external reference name to standard external reference name"""
standard_name = api_name
if not self._code_analyzer.is_standard_external_ref:
# key is real ref name, value is standard ref name.
mapping_names = self._mapping_standard_external_ref()
api_name_parts = api_name.split('.')
api_name_parts[0] = mapping_names.get(api_name_parts[0], api_name_parts[0])
standard_name = '.'.join(api_name_parts)
return standard_name
def _infer_api_name(self, call_func_node, check_context=True):
"""Infer the call name.
Examples:
1. nn.Sequential inferred to nn.Sequential
2. mmm.size inferred to .size if import torch.nn as nn
3. mmm.size inferred to mmm.size if import torch.nn as mmm
"""
match_case = ApiMatchingEnum.NOT_API
api_name = None
call_name = pasta.dump(call_func_node)
is_include_sub_call = self._is_include_sub_call(call_func_node)
if is_include_sub_call:
name_attributes = call_name.rsplit('.', 1)
else:
name_attributes = call_name.split('.')
# rewritten external module name
# e.g., mm.ReLU will be written to nn.ReLU if 'import torch.nn as mm' in script.
if check_context and not self._code_analyzer.is_standard_external_ref:
standard_name = self._mapping_standard_api_name(name_attributes[0])
else:
standard_name = name_attributes[0]
if standard_name in ["nn", "F", "torch"]:
match_case = ApiMatchingEnum.API_STANDARD
api_name = call_name
else:
# only infer function for tensor object.
# e.g., api_call_name is out.view, .view is an api name for out which is maybe a tensor object.
# e.g., 'xxxx'.size can be not inferred to .size, because string is not a tensor object.
first_name = standard_name.split('.')[0]
if not re.search(r'\W', first_name) and len(name_attributes) > 1:
api_name = '.' + name_attributes[-1]
match_case = ApiMatchingEnum.API_INFER
return api_name, match_case
@staticmethod
def _is_include_sub_call(call_func_node):
""""Inspect a sub call in call expression.
Examples:
1. nn.functional.relu() return False
2. nn.functional.relu(out).size() return True. nn.functional.relu(out) is sub call.
3. nn.functional.relu(out=out.size()).size() return False. out.size() is not sub call of argument.
"""
is_include_call = False
try:
sub_node = call_func_node
while sub_node and not isinstance(sub_node, ast.Call):
sub_node = sub_node.value
if isinstance(sub_node, ast.Call):
is_include_call = True
except AttributeError:
is_include_call = False
return is_include_call
def match_api(self, call_func_node, is_forward):
"""
Check api name to convert, check api name ok with a is_forward condition.
Args:
call_func_node (ast.Attribute): The call.func node.
is_forward (bool): whether api belong to forward.
Returns:
str, the standard api name used to match.
ApiMappingEnum, the match result.
"""
api_name, match_case = self._infer_api_name(call_func_node)
api_call_name = pasta.dump(call_func_node)
is_tensor_obj_call = False
if api_name != api_call_name:
is_tensor_obj_call = True
standard_api_call_name = api_name
# rewritten external module name
# e.g., mm.ReLU will be written to nn.ReLU if 'import torch.nn as mm' in script.
if not is_tensor_obj_call and not self._code_analyzer.is_standard_external_ref:
standard_api_call_name = self._mapping_standard_api_name(api_name)
if standard_api_call_name in ALL_TORCH_APIS:
match_case = ApiMatchingEnum.API_FOUND
if (not is_forward and standard_api_call_name in NN_LIST) or \
(is_forward and standard_api_call_name in ALL_2P_LIST):
match_case = ApiMatchingEnum.API_MATCHED
return standard_api_call_name, match_case
def mapping_api(self, call_node, check_context=True):
"""
Convert api_name in code to MindSpore api, if api_name is a python api, code will not convert.
If do not check context of the script, the code represented by the node must be written in the standard way.
Args:
call_node (ast.Call): The ast node to convert.
check_context (boolean): If True, the code context will be checked. Default is True.
Returns:
str, the converted code.
"""
if not isinstance(call_node, ast.Call):
raise NodeTypeNotSupport("It is not ast.Call node.")
code = pasta.dump(call_node)
api_call_name = pasta.dump(call_node.func)
if api_call_name.startswith('self.'):
return code
# find full api expected to be converted. eg:expr="nn.Conv2d(1,2,3)" args_str="(1,2,3)"
args_str = code[len(api_call_name):].strip()
try:
api_name, _ = self._infer_api_name(call_node.func, check_context)
standard_api_call_name = api_call_name
if api_name != api_call_name:
# api name .view inferred from out.view, split tensor object name is out
tensor_obj_name = api_call_name[:-len(api_name)]
map_helper = ALL_MAPPING[api_name]
new_code = map_helper.convert(tensor_obj_name, args_str)
else:
# change to external ref name
# e.g., mm.ReLU will be changed to nn.ReLU if 'import torch.nn as mm' in script.
if check_context and not self._code_analyzer.is_standard_external_ref:
standard_api_call_name = self._mapping_standard_api_name(api_name)
map_helper = ALL_MAPPING[standard_api_call_name]
new_code = map_helper.convert(standard_api_call_name, args_str)
except KeyError:
return code
return new_code
def visit_Call(self, node):
"""Callback function when visit AST tree"""
code = pasta.dump(node)
api_name = pasta.dump(node.func)
# parent node first call is equal to this node, skip when parent node is replaced.
for parent_node in self._stack[:-1]:
if parent_node in self._new_call_nodes and pasta.dump(parent_node).startswith(api_name):
return
parent = self._stack[-2]
new_node = None
matched_api_name, match_case = self.match_api(node.func, self._is_forward_function)
if match_case in [ApiMatchingEnum.API_INFER, ApiMatchingEnum.API_MATCHED]:
if matched_api_name in ALL_MAPPING:
logger.info("Line %3d start converting API: %s", node.lineno, api_name)
new_code = self.mapping_api(node)
if new_code != code:
new_node = pasta.parse(new_code).body[0].value
# find the first call name
new_api_name = new_code[:new_code.find('(')]
self._process_log.info(node.lineno, node.col_offset, LOG_FMT_CONVERT % (api_name, new_api_name))
if matched_api_name in ALL_UNSUPPORTED:
warn_info = UNSUPPORTED_WARN_INFOS.get(api_name, '')
logger.warning("Line %3d: found unsupported API: %s%s", node.lineno, api_name, warn_info)
self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, warn_info))
elif match_case in [ApiMatchingEnum.API_STANDARD, ApiMatchingEnum.API_FOUND]:
self._process_log.warning(node.lineno, node.col_offset, LOG_FMT_NOT_CONVERT % (api_name, ''))
else:
pass
if parent and new_node:
update_line_col = _LineColEditVisitor()
update_line_col.update(new_node, node)
pasta.ast_utils.replace_child(parent, node, new_node)
self._new_call_nodes.append(new_node)
node = new_node
self._stack[-1] = node
try:
self.generic_visit(node)
except Exception:
logger.error('original code:%s, new code:%s', code, new_code, exc_info=True)
raise
def _mapping_standard_external_ref(self):
"""Obtain the mapping dict of mapping the external references to standard external references."""
renames = {}
external_refs = self._code_analyzer.external_references
for ref_name, ref_info in external_refs.items():
external_ref_info = ref_info['external_ref_info']
if ref_name != 'nn' and external_ref_info.name == 'torch.nn':
renames[ref_name] = 'nn'
elif ref_name != 'F' and external_ref_info.name == 'torch.nn.functional':
renames[ref_name] = 'F'
return renames
...@@ -186,25 +186,23 @@ def cli_entry(): ...@@ -186,25 +186,23 @@ def cli_entry():
mode = permissions << 6 mode = permissions << 6
os.makedirs(args.output, mode=mode, exist_ok=True) os.makedirs(args.output, mode=mode, exist_ok=True)
os.makedirs(args.report, mode=mode, exist_ok=True) os.makedirs(args.report, mode=mode, exist_ok=True)
_run(args.in_file, args.output, '', args.report) _run(args.in_file, args.output, args.report)
def _run(in_files, out_dir, in_module, report): def _run(in_files, out_dir, report):
""" """
Run converter command. Run converter command.
Args: Args:
in_files (str): The file path or directory to convert. in_files (str): The file path or directory to convert.
out_dir (str): The output directory to save converted file. out_dir (str): The output directory to save converted file.
in_module (str): The module name to convert.
report (str): The report file path. report (str): The report file path.
""" """
files_config = { files_config = {
'root_path': in_files if in_files else '', 'root_path': in_files if in_files else '',
'in_files': [], 'in_files': [],
'outfile_dir': out_dir, 'outfile_dir': out_dir,
'report_dir': report, 'report_dir': report
'in_module': in_module
} }
if os.path.isfile(in_files): if os.path.isfile(in_files):
files_config['root_path'] = os.path.dirname(in_files) files_config['root_path'] = os.path.dirname(in_files)
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless REQUIRED by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""code analysis module"""
import ast
import pasta
from pasta.base import scope
from mindinsight.mindconverter.common.exceptions import ScriptNotSupport
class APIAnalysisSpec:
"""API analysis specifications"""
import_name_mapping = {'torch': ['mindspore', None],
'torch.nn': ['mindspore.nn', 'nn'],
'torch.nn.functional': ['mindspore.ops.operations', 'P']}
base_name_mapping = {'Module': 'Cell',
'Sequential': 'SequentialCell'
}
@classmethod
def get_convertible_external_names(cls):
"""
Obtain the convertible external names.
The external name is the full dotted name being referenced.
"""
return cls.import_name_mapping.keys()
@staticmethod
def get_network_base_class_names():
"""Obtain the base names which network class base from"""
return ['Module',
'Sequential',
'ModuleList',
'ModuleDict',
'ParameterList',
'ParameterDict']
@staticmethod
def check_external_alias_ref(ref_name, external_name):
"""
Check 'import as' is standard.
Standard references are follow:
import torch.nn as nn
import torch.nn.functional as F
Args:
ref_name (str): The name that refers to the external_name.
external_name (str): The full dotted name being referenced. For examples:
1. 'import torch.nn as nn', torch.nn is external_name, nn is ref_name.
2. 'from torch import nn as mm, torch.nn is external_name, mm is ref_name which is not a standard name.
Returns:
boolean, True if ref_name is standard else False.
"""
if ref_name != 'nn' and external_name == 'torch.nn':
is_standard = False
elif ref_name != 'F' and external_name == 'torch.nn.functional':
is_standard = False
else:
is_standard = True
return is_standard
class CodeAnalyzer(ast.NodeVisitor):
"""Code analyzer that analyzes PyTorch python script by AST Visitor.
CodeAnalyzer find the codes that need to be converted to MindSpore,
and provides the attributes related to the codes.
"""
def __init__(self):
self._stack = [] # Used to easily access the parent node
self._external_references = {}
self._is_standard_external_ref = True
self._root_scope = None
# Used to save functions that need to be converted, value type is pasta.base.scope.Scope
self._network_functions = []
# Used to easily trace the function node
self._functions_stack = []
# key type is pasta.base.scope.Scope, value type is list
self._network_classes = {}
@property
def root_scope(self):
"""The root scope of the python script code."""
return self._root_scope
@property
def is_standard_external_ref(self):
"""Obtain whether the result is a standard external reference."""
return self._is_standard_external_ref
@property
def external_references(self):
"""Obtain all external references in the analyzed code."""
return self._external_references
def network_definitions(self):
"""Obtain the network definitions which need to be converted."""
return {"functions": self._network_functions,
"cell": self._network_classes}
def process(self, ast_tree):
"""
Start to analyze the code.
Args:
ast_tree (AST): The root node of the source code.
"""
self.__init__()
self._root_scope = scope.analyze(ast_tree)
self._pre_process()
self.visit(ast_tree)
if not self._network_classes:
msg = "model definition not be found."
raise ScriptNotSupport(msg)
@staticmethod
def _check_external_standard(external_refs):
"""Check whether all external references are standard."""
is_standard = True
for external_name, external_ref_info in external_refs.items():
is_standard = APIAnalysisSpec.check_external_alias_ref(external_name, external_ref_info.name)
if not is_standard:
break
return is_standard
def _is_base_from_cell(self, node):
"""
Check whether the node bases from cell classes which are defined in APIAnalysisSpec.
Args:
node (ast.ClassDef): The node which is a class definition.
Returns:
boolean, True if the check result is Passed else False.
"""
if self._is_ref_convertible_imports(node):
whole_name = self._get_whole_name(node)
if whole_name.split('.')[-1] in APIAnalysisSpec.get_network_base_class_names():
return True
return False
def _pre_process(self):
"""Preprocessor checks the code before analyzing."""
is_torch = False
# check whether the code imports torch.
for ref_name in self._root_scope.external_references.keys():
if ref_name.split('.')[0] in APIAnalysisSpec.get_convertible_external_names():
is_torch = True
break
if not is_torch:
msg = "The source code does not import torch, model definition can not be found."
raise ScriptNotSupport(msg)
# Find out external reference in the code and save it.
external_refs = self._analyze_import_references(self._root_scope)
self._is_standard_external_ref = self._check_external_standard(external_refs)
self._check_external_standard(external_refs)
for external_name, external_ref_info in external_refs.items():
self._external_references.update({
external_name: {
'external_ref_info': external_ref_info,
'parent_node': None
}
})
@staticmethod
def _analyze_import_references(root_scope):
"""Find out all references from the import statements."""
external_name_ref = {}
for node_references in root_scope.external_references.values():
for node_ref in node_references:
if node_ref.name_ref:
# (from)import alias, node_ref.name_ref.id is alias name
if node_ref.name_ref.definition.asname == node_ref.name_ref.id:
external_name_ref[node_ref.name_ref.id] = node_ref
# import without alias, node_ref.name_ref.definition.asname is None.
# e.g., import a.b.c, reference maybe is a, a.b or a.b.c in the root_scope.external_references.
# The reference a.b.c is really wanted.
elif node_ref.name_ref.definition.name == node_ref.name_ref.id:
external_name_ref[node_ref.name_ref.id] = node_ref
else:
pass
return external_name_ref
def visit(self, node):
"""Overridden visit of the base class to maintain stack information to access parent node."""
self._stack.append(node)
super(CodeAnalyzer, self).visit(node)
self._stack.pop()
@staticmethod
def _get_full_name(node):
"""Get the full name of the node."""
if not isinstance(node, (ast.Attribute, ast.Name)):
return None
return pasta.dump(node)
def _get_whole_name(self, node):
"""
Get the whole name of the node.
For example, nn.Module is spliced two nodes, nn node and Module node.
When visit ast nodes,
Module node is first visited, the full name is the same as the whole name, that is nn.Module.
And then nn node is visited, the full name is nn, the whole name is nn.Module.
"""
full_name = self._get_full_name(node)
if not full_name:
return None
# node is in stack top pos
if node is self._stack[-1]:
parent_index = -1
while isinstance(self._stack[parent_index], ast.Attribute):
parent_index -= 1
whole_name = self._get_full_name(self._stack[parent_index])
else:
whole_name = full_name
return whole_name
def _is_ref_convertible_imports(self, node):
"""Check whether the node references convertible imports."""
check_result = False
whole_name = self._get_whole_name(node)
if whole_name:
module_name = whole_name.split('.')[0]
for ref_name, ref_info in self._external_references.items():
external_ref = ref_info['external_ref_info']
# external reference is convertible module
if external_ref.name in APIAnalysisSpec.get_convertible_external_names():
# import from the same external module
if module_name == ref_name.split('.')[0]:
check_result = True
break
return check_result
@staticmethod
def _get_external_node(external_references):
"""Get all external reference nodes."""
external_nodes = {}
for ref_name, ref_info in external_references.items():
external_nodes.update({ref_info['external_ref_info'].node: ref_name})
return external_nodes
@staticmethod
def _get_convertible_external_node(external_name_ref):
"""Get all convertible external reference nodes."""
convertible_external_nodes = {}
for ref_name, ref_info in external_name_ref.items():
if ref_info['external_ref_info'].name in APIAnalysisSpec.get_convertible_external_names():
convertible_external_nodes.update({ref_info['external_ref_info'].node: ref_name})
return convertible_external_nodes
def _update_external_ref_parent(self, node):
"""Set external reference parent node info."""
external_nodes = self._get_external_node(self._external_references)
convertible_external_nodes = self._get_convertible_external_node(self._external_references)
for name_node in node.names:
if name_node in convertible_external_nodes.keys():
if len(node.names) > 1:
msg = """\
Not support multiple imports of torch on one line in your script. line:%s: %s
""" % (node.lineno, pasta.dump(node))
raise ScriptNotSupport(msg)
if name_node in external_nodes.keys():
ref_name = external_nodes[name_node]
self._external_references[ref_name]['parent_node'] = node
@staticmethod
def _get_class_scope(node_scope):
"""Find the class scope of the node_scope."""
parent_scope = node_scope.parent_scope
class_scope = None
while parent_scope:
if isinstance(parent_scope.node, ast.ClassDef):
class_scope = parent_scope
break
parent_scope = parent_scope.parent_scope
return class_scope
def _update_convertible_functions(self, node):
"""Update convertible functions."""
node_scope = self._root_scope.lookup_scope(node)
class_scope = self._get_class_scope(node_scope)
if class_scope:
network_classes = self._network_classes.get(class_scope, [])
if node_scope not in network_classes:
network_classes.append(node_scope)
else:
if node_scope not in self._network_functions:
self._network_functions.append(node_scope)
def visit_ClassDef(self, node):
"""Callback function when visit AST tree"""
if not self._stack[-1] is node:
return
for base in node.bases:
if self._is_ref_convertible_imports(base):
self._network_classes[self._root_scope.lookup_scope(node)] = []
self.generic_visit(node)
def visit_Import(self, node):
"""Callback function when visit AST tree"""
self._update_external_ref_parent(node)
self.generic_visit(node)
def visit_ImportFrom(self, node):
"""Callback function when visit AST tree"""
self._update_external_ref_parent(node)
self.generic_visit(node)
def visit_Call(self, node):
"""Callback function when visit AST tree"""
if not self._stack[-1] is node:
return
is_in_network_function = False
# If torch call is happened in the function, save the function for network definition.
if self._functions_stack and self._is_ref_convertible_imports(node.func):
self._update_convertible_functions(self._functions_stack[-1])
is_in_network_function = True
if not is_in_network_function:
self.generic_visit(node)
def visit_FunctionDef(self, node):
"""Callback function when visit AST tree"""
if not self._stack[-1] is node:
return
if node.name == "forward":
self._update_convertible_functions(node)
self._functions_stack.append(node)
self.generic_visit(node)
self._functions_stack.pop()
def get_name(self, node):
"""
Get the node name.
Args:
node (AST): The ast node of the source code.
Returns:
str, the name of the node
"""
if isinstance(node, pasta.base.scope.Scope):
items = [self.get_name(node.node)]
parent_scope = node.parent_scope
while parent_scope:
if not isinstance(parent_scope.node, ast.Module):
items.append(self.get_name(parent_scope.node))
parent_scope = parent_scope.parent_scope
return '.'.join(reversed(items))
if isinstance(node, (ast.ClassDef, ast.FunctionDef)):
return node.name
if isinstance(node, (ast.Name, ast.Attribute)):
return self._get_full_name(node)
return str(node)
def lookup_scope(self, node):
"""
Search the scope of the node.
Args:
node (AST): The ast node of the source code.
Returns:
scope, the scope of the node
"""
if isinstance(node, pasta.base.scope.Scope):
return node
return self._root_scope.lookup_scope(node)
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Define custom exception."""
from enum import unique
from mindinsight.utils.constant import ScriptConverterErrors
from mindinsight.utils.exceptions import MindInsightException
@unique
class ConverterErrors(ScriptConverterErrors):
"""Converter error codes."""
SCRIPT_NOT_SUPPORT = 1
NODE_TYPE_NOT_SUPPORT = 2
class ScriptNotSupport(MindInsightException):
"""The script can not support to process."""
def __init__(self, msg):
super(ScriptNotSupport, self).__init__(ConverterErrors.SCRIPT_NOT_SUPPORT,
msg,
http_code=400)
class NodeTypeNotSupport(MindInsightException):
"""The astNode can not support to process."""
def __init__(self, msg):
super(NodeTypeNotSupport, self).__init__(ConverterErrors.NODE_TYPE_NOT_SUPPORT,
msg,
http_code=400)
...@@ -13,464 +13,89 @@ ...@@ -13,464 +13,89 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""converter module""" """converter module"""
import copy
import importlib
import inspect
import os import os
import stat import stat
from mindinsight.mindconverter.config import ALL_MAPPING import pasta
from mindinsight.mindconverter.config import NN_LIST
from mindinsight.mindconverter.config import ALL_TORCH_APIS
from mindinsight.mindconverter.config import ALL_2P_LIST
from mindinsight.mindconverter.config import UNSUPPORTED_WARN_INFOS
from mindinsight.mindconverter.config import ALL_UNSUPPORTED
from mindinsight.mindconverter.common.log import logger
from mindinsight.mindconverter.forward_call import ForwardCall
LINE_NO_INDEX_DIFF = 1 from mindinsight.mindconverter.common.exceptions import ScriptNotSupport
from mindinsight.mindconverter.common.log import logger
from mindinsight.mindconverter.ast_edits import AstEditVisitor
class Converter: class Converter:
"""Convert class""" """Convert class"""
convert_info = ''
flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL
modes = stat.S_IWUSR | stat.S_IRUSR modes = stat.S_IWUSR | stat.S_IRUSR
@staticmethod def __init__(self):
def is_local_defined(obj, member): self._tree = None
""" self._infile = None
Check if obj and member are both defined in the same source file. self._code_analyzer = None
self._ast_editor = None
Args: self._report = []
obj (Union[object, module]): A module or a class.
member (func): A function of obj.
Returns:
bool, True or False.
"""
srcfile = inspect.getsourcefile(obj)
return inspect.getsourcefile(member) == srcfile
@classmethod
def is_valid_module(cls, obj, member):
"""
Check if obj and member defined in same source file and member is inherited from torch.nn.Module.
Args:
obj (Union[object, module]): A module or a class.
member (func): A function.
Returns:
bool, True or False.
"""
if inspect.isclass(member):
is_subclass = member.__base__.__name__ in ['Module',
'Sequential',
'ModuleList',
'ModuleDict',
'ParameterList',
'ParameterDict']
return is_subclass and cls.is_local_defined(obj, member)
return False
@classmethod
def is_valid_function(cls, obj, member):
"""
Check if member is function and defined in the file same as obj.
Args:
obj (Union[object, module]: The obj.
member (func): The func.
Returns:
bool, True or False.
"""
return inspect.isfunction(member) and cls.is_local_defined(obj, member)
@staticmethod
def find_left_parentheses(string, right):
"""
Find index of the first left parenthesis.
Args:
string (str): A line of code.
right (int): The right index for string to find from.
Returns:
int, index of the first parenthesis.
Raises:
ValueError: If line of code doesn't contain any pair of `()` or `(` and `)` are not paired.
"""
if string[right] != ')':
raise ValueError('code [{}] at index {} not ")".'.format(string, right))
stack = []
for i in range(right, -1, -1):
if string[i] == ')':
stack.append(')')
elif string[i] == '(':
stack.pop()
if not stack:
return i
raise ValueError("{} should contain ()".format(string))
@staticmethod
def find_right_parentheses(string, left):
"""
Find first index of right parenthesis which make all left parenthesis make sense.
Args:
string (str): A line of code.
left (int): Start index of string to find from.
Returns:
int, index of the found right parenthesis.
Raises:
ValueError: If line of code doesn't contain any pair of `()` or `(` and `)` are not paired.
"""
stack = []
for i in range(left, len(string)):
if string[i] == '(':
stack.append('(')
elif string[i] == ')':
stack.pop()
if not stack:
return i
raise ValueError("{} should contain ()".format(string))
@staticmethod
def get_call_name(code, end):
"""
Traverse code in a reversed function from index end and get the call name and start index of the call name,
if call name not found, return a null character string and -1
Args:
code (str): The str of code to find from.
end (int): Start index to find.
Returns:
tuple(str, int), one is founded api name if found, else a null character string, the other is start index
of founded api name, -1 if api name not found
"""
stack = []
for i in range(end - 1, -1, -1):
if code[i] in ["(", "[", "{"]:
if stack:
stack.pop()
else:
return code[i + 1:end], i + 1
elif code[i] in [")", "]", "}"]:
stack.append(code[i])
elif stack:
continue
elif not (code[i].isalpha() or code[i].isdigit() or code[i] == '_' or code[i] == '.'):
return code[i + 1:end], i + 1
return "", -1
def convert_api(self, code, start, api_name=""):
"""
Convert api_name in code to MindSpore api with start as a start index, if api_name is a python api,
code will not convert.
Args:
code (str): The str code to convert.
start (int): The index of code to start convert from.
api_name (str): The api name to convert.
Returns:
str, the converted code.
int, index of converted api_name in code.
"""
# handle format like .shape(
if api_name.startswith('.'):
call_name, new_start = self.get_call_name(code, start)
if start == -1 or call_name == "self":
return code, start + 1
else:
call_name = api_name
new_start = start
# find full api expected to be converted. eg:expr="nn.Conv2d(1,2,3)" args_str="(1,2,3)"
left = code.find("(", start)
if left == -1:
raise ValueError('"(" not found, {} should work with "("'.format(call_name))
right = self.find_right_parentheses(code, left)
end = right
expr = code[start:end + 1]
args_str = code[left:right + 1]
map_helper = ALL_MAPPING[api_name]
new_expr = map_helper.convert(call_name, args_str)
next_newline = code.find("\n", end + 1)
fill_num = (expr.count("\n") - new_expr.count("\n"))
if next_newline != -1:
code = code[:new_start] + new_expr + code[end + 1:next_newline] + ("\n" * fill_num) + code[next_newline:]
else:
code = code[:new_start] + new_expr + ")" + ("\n" * fill_num) + code[end + 2:]
return code, start + len(map_helper.ms_api.name)
@staticmethod
def find_api(code, i, is_forward):
"""
Find api name from code with a start index i, check api name ok with a is_forward condition.
Args:
code (str): The code from which to find api name.
i (int): The start index to find.
is_forward (bool): Check if the found api name ok.
Returns:
str, api name if find api name and check ok with is_forward condition, else a null character string.
"""
if code[i:].startswith("nn.") \
or code[i:].startswith("F.") \
or code[i:].startswith("torch.") \
or code[i:].startswith('.'):
j = code.find('(', i)
if j != -1 and code[i:j] in ALL_TORCH_APIS:
api_name = code[i:j]
if (not is_forward and api_name in NN_LIST) or (is_forward and api_name in ALL_2P_LIST):
return api_name
return ""
def convert_function(self, fun_name, fun, is_forward):
"""
Convert a PyTorch function into MindSpore function.
Args:
fun_name (str): The str of function name.
fun (func): The function to convert.
is_forward (bool): If the function is defined in forward function in nn.Module in torch.
Returns:
dict, old code and converted code map if convert happens, else {}.
"""
_, line_no = inspect.getsourcelines(fun)
logger.info("Line %3d: start converting function %s()", line_no, fun_name)
code = inspect.getsource(fun)
code_saved = copy.copy(code)
i = 0
while i < len(code):
api_name = self.find_api(code, i, is_forward)
if api_name:
line_no1 = line_no + code[:i].count('\n')
if api_name in ALL_MAPPING:
logger.info("Line %3d start converting API: %s", line_no1, api_name)
code, i = self.convert_api(code, i, api_name)
self.convert_info += "[Convert][Line{:3d}] {} is converted.\n".format(line_no1, api_name)
continue
if api_name in ALL_UNSUPPORTED:
warn_info = ". " + UNSUPPORTED_WARN_INFOS[api_name] if api_name in UNSUPPORTED_WARN_INFOS else ""
logger.warning("Line %3d: found unsupported API: %s%s", line_no1, api_name, warn_info)
self.convert_info += "[Unconvert][Line{:3d}] {} didn't convert{}\n".format(line_no1,
api_name, warn_info)
i += 1
return {code_saved: code} if code_saved != code else {}
@staticmethod
def judge_forward(name, forward_list):
"""
Check if function is a forward function.
Args:
name (str): The function name.
forward_list (set): A set of forward function.
Returns:
bool, True or False
"""
is_forward = name in forward_list or name.split(".")[-1] == "forward"
if is_forward:
logger.debug("%s is a forward function", name)
return is_forward
def convert_module(self, module_name, module, forward_list):
"""
Convert a PyTorch module code into MindSpore module code.
Args:
module_name (str): The module's name.
module (module): The module to convert.
forward_list (set): A set of forward function.
Returns:
dict, map of old code and converted code.
"""
_, line_no = inspect.getsourcelines(module)
logger.info("Line {:3d}: start converting nn.Module {}".format(line_no, module_name))
mapped = {}
for name, member in inspect.getmembers(module):
if self.is_valid_function(module, member):
is_forward = self.judge_forward("{}.{}".format(module_name, name), forward_list)
mapped.update(self.convert_function(name, member, is_forward))
return mapped
def get_mapping(self, import_mod, forward_list):
"""
Convert code of a module and get mapping of old code and convert code.
Args:
import_mod (module): The module to convert.
forward_list (set): A set of forward function.
Returns:
dict, mapping for old code and converted code of the module
"""
mapping = {}
tasks = []
for name, member in inspect.getmembers(import_mod):
if self.is_valid_module(import_mod, member):
_, line_no = inspect.getsourcelines(member)
tasks.append((line_no, self.convert_module, (name, member, forward_list)))
elif self.is_valid_function(import_mod, member):
_, line_no = inspect.getsourcelines(member)
is_forward = self.judge_forward("{}.{}".format(import_mod, name), forward_list)
tasks.append((line_no, self.convert_function, (name, member, is_forward)))
tasks.sort()
for _, convert_fun, args in tasks:
mapping.update(convert_fun(*args))
return mapping
@staticmethod
def get_code_start_line_num(source_lines):
"""
Get the start code line number exclude comments.
Args:
source_lines (list[str]): Split results of original code.
Returns:
int, the start line number.
"""
stack = []
index = 0
for i, line in enumerate(source_lines):
if line.strip().startswith('#'):
continue
if line.strip().startswith('"""'):
if not line.endswith('"""\n'):
stack.append('"""')
continue
if line.strip().startswith("'''"):
if not line.endswith("'''\n"):
stack.append("'''")
continue
if line.endswith('"""\n') or line.endswith("'''\n"):
stack.pop()
continue
if line.strip() != '' and not stack:
index = i
break
return index
def update_code_and_convert_info(self, code, mapping):
"""
Replace code according to mapping, and update convert info.
Args:
code (str): The code to replace.
mapping (dict): Mapping for original code and the replaced code.
Returns:
str, the replaced code.
"""
for key, value in mapping.items():
code = code.replace(key, value)
source_lines = code.splitlines(keepends=True)
start_line_number = self.get_code_start_line_num(source_lines)
add_import_infos = ['import mindspore\n',
'import mindspore.nn as nn\n',
'import mindspore.ops.operations as P\n']
for i, add_import_info in enumerate(add_import_infos):
source_lines.insert(start_line_number + i, add_import_info)
self.convert_info += '[Add Import] {}.\n'.format(add_import_info.strip())
insert_count = len(add_import_infos)
line_diff = insert_count - LINE_NO_INDEX_DIFF
for i in range(start_line_number + insert_count, len(source_lines)): def convert(self, infile, output_dir, report_dir):
line = source_lines[i]
if (line.startswith('from torch') and 'import' in line) or line.startswith('import torch'):
new_line = '# ' + line
source_lines[i] = new_line
self.convert_info += '[Annotate][Line{:3d}] {} is annotated.\n'.format(i - line_diff, line.strip())
if line.strip().startswith('class') and '(nn.Module)' in line:
new_line = line.replace('nn.Module', 'nn.Cell')
source_lines[i] = new_line
self.convert_info += '[Convert][Line{:3d}] nn.Module is converted.\n'.format(i - line_diff)
if line.strip().startswith('def forward('):
new_line = line.replace('forward', 'construct')
source_lines[i] = new_line
self.convert_info += '[Convert][Line{:3d}] forward is converted.\n'.format(i - line_diff)
if 'nn.Linear' in line:
new_line = line.replace('nn.Linear', 'nn.Dense')
source_lines[i] = new_line
self.convert_info += '[Convert][Line{:3d}] nn.Linear is converted.\n'.format(i - line_diff)
if '(nn.Sequential)' in line:
new_line = line.replace('nn.Sequential', 'nn.SequentialCell')
source_lines[i] = new_line
self.convert_info += '[Convert][Line{:3d}] nn.Sequential is converted.\n'.format(i - line_diff)
if 'nn.init.' in line:
new_line = line.replace('nn.init', 'pass # nn.init')
source_lines[i] = new_line
self.convert_info += '[Annotate][Line{:3d}] {} is annotated.\n'.format(i - line_diff, 'nn.init')
code = ''.join(source_lines)
return code
def convert(self, import_name, output_dir, report_dir):
""" """
Convert a module's code, code converted will be save in output_dir, and a report will be save in report_dir. Convert a module's code, code converted will be save in output_dir, and a report will be save in report_dir.
Args: Args:
import_name (str): The module from which to import the module to convert. infile (str): The script to convert.
output_dir (str): The path to save converted file. output_dir (str): The path to save converted file.
report_dir (str): The path to save report file. report_dir (str): The path to save report file.
""" """
logger.info("Start converting %s", import_name) in_file_split = _path_split(infile)
start_info = '[Start Convert]\n' in_file_split[-1], _ = _get_name_ext(in_file_split[-1])
module_info = 'The module is {}.\n'.format(import_name) module_name = '.'.join(in_file_split)
with open(infile, 'r') as file:
import_mod = importlib.import_module(import_name) content = ''.join(file.readlines())
srcfile = inspect.getsourcefile(import_mod)
logger.info("Script file is %s", srcfile) self._infile = infile
self._tree = pasta.parse(content)
forward_list = set(ForwardCall(srcfile).calls) self._report.clear()
logger.debug("Forward_list: %s", forward_list) try:
logger.info("Script file is %s", infile)
# replace python function under nn.Module logger.info("Start converting %s", module_name)
mapping = self.get_mapping(import_mod, forward_list) self._report.append('[Start Convert]')
code = inspect.getsource(import_mod) self._ast_editor = AstEditVisitor()
code = self.update_code_and_convert_info(code, mapping) self._ast_editor.process(self._tree)
convert_info_split = self.convert_info.splitlines(keepends=True) self._report.extend(self._ast_editor.get_logs())
convert_info_split = sorted(convert_info_split) self._report.append('[Convert Over]')
convert_info_split.insert(0, start_info) dest_file = os.path.join(output_dir, os.path.basename(infile))
convert_info_split.insert(1, module_info)
convert_info_split.append('[Convert Over]')
self.convert_info = ''.join(convert_info_split)
dest_file = os.path.join(output_dir, os.path.basename(srcfile))
with os.fdopen(os.open(dest_file, self.flags, self.modes), 'w') as file: with os.fdopen(os.open(dest_file, self.flags, self.modes), 'w') as file:
file.write(code) file.write(pasta.dump(self._tree))
logger.info("Convert success. Result is wrote to %s.", dest_file) logger.info("Convert success. Result is wrote to %s.", dest_file)
except ScriptNotSupport as error:
self._report.append('[ScriptNotSupport] ' + error.message)
self._report.append('[Convert failed]')
raise error
except Exception as error:
self._report.clear()
raise error
finally:
if self._report:
dest_report_file = os.path.join(report_dir, dest_report_file = os.path.join(report_dir,
'_'.join(os.path.basename(srcfile).split('.')[:-1]) + '_report.txt') '_'.join(os.path.basename(infile).split('.')[:-1]) + '_report.txt')
with os.fdopen(os.open(dest_report_file, self.flags, self.modes), 'a') as file: with os.fdopen(os.open(dest_report_file, self.flags, self.modes), 'a') as file:
file.write(self.convert_info) file.write('\n'.join(self._report))
logger.info("Convert report is saved in %s", dest_report_file) logger.info("Convert report is saved in %s", dest_report_file)
@staticmethod
def convert_api(source_code):
"""
Convert api_name in code to MindSpore api, if api_name is a python api, code will not convert.
Args:
source_code (ast.Call): The ast node to convert.
Returns:
str, the converted code.
"""
ast_node = pasta.parse(source_code).body[0].value
check_context = False
replaced_code = AstEditVisitor().mapping_api(ast_node, check_context)
return replaced_code
def _get_name_ext(file): def _get_name_ext(file):
""" """
...@@ -514,14 +139,6 @@ def main(files_config): ...@@ -514,14 +139,6 @@ def main(files_config):
files_config (dict): The config of files which to convert. files_config (dict): The config of files which to convert.
""" """
convert_ins = Converter() convert_ins = Converter()
root_path = files_config['root_path']
in_files = files_config['in_files'] in_files = files_config['in_files']
for in_file in in_files: for in_file in in_files:
in_file_split = _path_split(in_file[len(root_path):]) convert_ins.convert(in_file, files_config['outfile_dir'], files_config['report_dir'])
in_file_split[-1], _ = _get_name_ext(in_file_split[-1])
module_name = '.'.join(in_file_split)
convert_ins.convert(module_name, files_config['outfile_dir'], files_config['report_dir'])
in_module = files_config.get('in_module')
if in_module:
convert_ins.convert(in_module, files_config['outfile_dir'], files_config['report_dir'])
...@@ -14,7 +14,8 @@ ...@@ -14,7 +14,8 @@
# ============================================================================ # ============================================================================
"""Find out forward functions of script file""" """Find out forward functions of script file"""
import ast import ast
import os
import pasta
class ForwardCall(ast.NodeVisitor): class ForwardCall(ast.NodeVisitor):
...@@ -24,73 +25,80 @@ class ForwardCall(ast.NodeVisitor): ...@@ -24,73 +25,80 @@ class ForwardCall(ast.NodeVisitor):
Find the sub functions called by the forward function in the script file. Find the sub functions called by the forward function in the script file.
""" """
def __init__(self, filename): def __init__(self, ast_tree):
self.filename = filename self._tree = ast_tree
self.module_name = os.path.basename(filename).replace('.py', '') self._name_stack = []
self.name_stack = [] self._forward_stack = []
self.forward_stack = [] self.calls = {} # key is function name, value is forward function ast node.
self.calls = set() self._function_list = {} # key is function name, value is function ast node.
self.process() self.process()
def process(self): def process(self):
"""Parse the python source file to find the forward functions.""" """visit ast tree to find the forward functions."""
with open(self.filename, 'rt', encoding='utf-8') as file: self.visit(self._tree)
content = file.read() # first visit to find out all functions, so restores all variables except _function_list
self.visit(ast.parse(content, self.filename)) self._name_stack.clear()
self._forward_stack.clear()
self.calls.clear()
self.visit(self._tree)
def get_current_namespace(self): def get_current_namespace(self):
"""Get the namespace when visit the AST node""" """Get the namespace when visit the AST node"""
namespace = '.'.join(self.name_stack) namespace = '.'.join(self._name_stack)
return namespace return namespace
@classmethod @classmethod
def get_ast_node_name(cls, node): def get_call_name(cls, node):
"""Get AST node name.""" """Get functional call name."""
if isinstance(node, ast.Attribute): if not isinstance(node, ast.Call):
return f'{cls.get_ast_node_name(node.value)}.{node.attr}' return None
if isinstance(node, ast.Name):
return node.id
return node return pasta.dump(node.func)
def visit_ClassDef(self, node): def visit_ClassDef(self, node):
"""Callback function when visit AST tree""" """Callback function when visit AST tree"""
self.name_stack.append(node.name) self._name_stack.append(node.name)
self.generic_visit(node) self.generic_visit(node)
self.name_stack.pop() self._name_stack.pop()
def visit_FunctionDef(self, node): def visit_FunctionDef(self, node):
"""Callback function when visit AST tree""" """Callback function when visit AST tree"""
namespace = self.get_current_namespace()
if namespace:
func_name = f'{namespace}.{node.name}'
else:
func_name = node.name
func_name = f'{self.get_current_namespace()}.{node.name}' func_name = f'{self.get_current_namespace()}.{node.name}'
is_in_chain = func_name in self.calls or node.name == 'forward' is_in_chain = func_name in self.calls or node.name == 'forward'
if is_in_chain: if is_in_chain:
self.forward_stack.append(func_name) self._forward_stack.append(func_name)
if node.name == 'forward': if node.name == 'forward':
self.calls.add(func_name) self.calls.update({func_name: node})
self._function_list.update({func_name: node})
self.generic_visit(node) self.generic_visit(node)
if is_in_chain: if is_in_chain:
self.forward_stack.pop() self._forward_stack.pop()
def visit_Call(self, node): def visit_Call(self, node):
"""Callback function when visit AST tree""" """Callback function when visit AST tree"""
for arg in node.args: for arg in node.args:
self.visit(arg) self.visit(arg)
for kw in node.keywords: for keyword in node.keywords:
self.visit(kw.value) self.visit(keyword.value)
func_name = self.get_ast_node_name(node.func) func_name = self.get_call_name(node)
if isinstance(node.func, ast.Name): if isinstance(node.func, ast.Name):
if func_name not in ['super', 'str', 'repr']: if func_name not in ['super', 'str', 'repr']:
if self.forward_stack: if self._forward_stack:
self.calls.add(func_name) self.calls.update({func_name: self._function_list.get(func_name)})
self.visit(node.func) self.visit(node.func)
else: else:
if self.forward_stack: if self._forward_stack:
if 'self' in func_name: if func_name.startswith('self.'):
self.calls.add(f'{self.get_current_namespace()}.{func_name.split(".")[-1]}') whole_name = f'{self.get_current_namespace()}.{func_name.split(".")[-1]}'
self.calls.update({whole_name: self._function_list.get(whole_name)})
else: else:
self.calls.add(func_name) self.calls.update({func_name: self._function_list.get(func_name)})
self.visit(node.func) self.visit(node.func)
...@@ -30,6 +30,7 @@ class MindInsightModules(Enum): ...@@ -30,6 +30,7 @@ class MindInsightModules(Enum):
LINEAGEMGR = 2 LINEAGEMGR = 2
DATAVISUAL = 5 DATAVISUAL = 5
PROFILERMGR = 6 PROFILERMGR = 6
SCRIPTCONVERTER = 7
class GeneralErrors(Enum): class GeneralErrors(Enum):
...@@ -69,3 +70,7 @@ class DataVisualErrors(Enum): ...@@ -69,3 +70,7 @@ class DataVisualErrors(Enum):
SCALAR_NOT_EXIST = 14 SCALAR_NOT_EXIST = 14
HISTOGRAM_NOT_EXIST = 15 HISTOGRAM_NOT_EXIST = 15
TRAIN_JOB_DETAIL_NOT_IN_CACHE = 16 TRAIN_JOB_DETAIL_NOT_IN_CACHE = 16
class ScriptConverterErrors(Enum):
"""Enum definition for mindconverter errors."""
...@@ -22,380 +22,201 @@ class TestConverter: ...@@ -22,380 +22,201 @@ class TestConverter:
converter_ins = Converter() converter_ins = Converter()
def test_judge_forward(self):
"""test judge_forward"""
name1 = 'conv1'
forward_list = {'conv1', 'relu'}
result1 = self.converter_ins.judge_forward(name1, forward_list)
assert result1 is True
name2 = 'self.forward'
result2 = self.converter_ins.judge_forward(name2, forward_list)
assert result2 is True
def test_find_left_parentheses(self):
"""test find_left_parentheses"""
code = '''nn.Sequential(nn.Conv2d(in_dim, 6, 5, stride=1, padding=0, ),
nn.ReLU(),
nn.ReLU(True),
nn.MaxPool2d(2, 2),
nn.Conv2d(6, 16, 5, stride=1, padding=0),
nn.ReLU(inplace=False),
nn.MaxPool2d(2, 2))'''
right_index = len(code) - 1
left_index = code.index('nn.Conv2d')
result = self.converter_ins.find_left_parentheses(code, right_index)
assert result == left_index - 1
def test_find_api(self):
"""test find_api"""
code = '''nn.Sequential(nn.Conv2d(in_dim, 6, 5, stride=1, padding=0, ),
nn.ReLU(),
nn.ReLU(True),
nn.MaxPool2d(2, 2), # TODO padding
nn.Conv2d(6, 16, 5, stride=1, padding=0),
nn.ReLU(inplace=False),
nn.MaxPool2d(2, 2))'''
index = 0
is_forward = False
result = self.converter_ins.find_api(code, index, is_forward)
assert result == 'nn.Sequential'
def test_get_call_name(self):
"""test get_call_name"""
code = '''nn.Sequential(nn.Conv2d(in_dim, 6, 5, stride=1, padding=0))'''
end = len(code)
call_name, index = self.converter_ins.get_call_name(code, end)
assert call_name == ''
assert index == -1
def test_find_right_parentheses(self):
"""test find_right_parentheses"""
code = '''nn.Sequential(nn.Conv2d(in_dim, 6, 5, stride=1, padding=0, ),
nn.ReLU(),
nn.ReLU(True),
nn.MaxPool2d(2, 2), # TODO padding
nn.Conv2d(6, 16, 5, stride=1, padding=0),
nn.ReLU(inplace=False),
nn.MaxPool2d(2, 2))'''
left_index = 0
result = self.converter_ins.find_right_parentheses(code, left_index)
assert_index = len(code) - 1
assert result == assert_index
# test convert_api with nn ops # test convert_api with nn ops
def test_convert_api_nn_layernorm(self): def test_convert_api_nn_layernorm(self):
"""Test convert_api function work ok when convert api nn.LayerNorm""" """Test convert_api function work ok when convert api nn.LayerNorm"""
code = """ code = "nn.LayerNorm((5, 10, 10), elementwise_affine=False)"
def __init__(self, num_classes=1000):
self.features = nn.SequentialCell([
nn.LayerNorm((5, 10, 10), elementwise_affine=False),
nn.ReLU(inplace=False)
])
"""
api_name = 'nn.LayerNorm' api_name = 'nn.LayerNorm'
start = code.find(api_name)
layer_norm_info = NN_MAPPING.get(api_name) layer_norm_info = NN_MAPPING.get(api_name)
expected_ms_api_name = 'nn.LayerNorm' expected_ms_api_name = 'nn.LayerNorm'
epsilon = layer_norm_info.pt_api.params.get('eps') epsilon = layer_norm_info.pt_api.params.get('eps')
replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('nn.LayerNorm((5, 10, 10), elementwise_affine=False)', assert replaced_code == code.replace('nn.LayerNorm((5, 10, 10), elementwise_affine=False)',
'{}(normalized_shape=(5, 10, 10), epsilon={})'.format( '{}(normalized_shape=(5, 10, 10), epsilon={})'.format(
expected_ms_api_name, epsilon)) expected_ms_api_name, epsilon))
assert new_start == start + len(expected_ms_api_name)
def test_convert_api_nn_leaky_relu(self): def test_convert_api_nn_leaky_relu(self):
"""Test convert_api function work ok when convert api nn.LeakyReLU""" """Test convert_api function work ok when convert api nn.LeakyReLU"""
code = """ code = "nn.LeakyReLU(0.3)"
def __init__(self, num_classes=1000):
self.features = nn.SequentialCell([
nn.LayerNorm((5, 10, 10), elementwise_affine=False),
nn.LeakyReLU(0.3)])
"""
api_name = 'nn.LeakyReLU'
start = code.find(api_name)
expected_ms_api_name = 'nn.LeakyReLU' expected_ms_api_name = 'nn.LeakyReLU'
replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('nn.LeakyReLU(0.3)', assert replaced_code == code.replace('nn.LeakyReLU(0.3)',
'{}(alpha=0.3)'.format(expected_ms_api_name)) '{}(alpha=0.3)'.format(expected_ms_api_name))
assert new_start == start + len(expected_ms_api_name)
def test_convert_api_nn_prelu(self): def test_convert_api_nn_prelu(self):
"""Test convert_api function work ok when convert api nn.PReLU""" """Test convert_api function work ok when convert api nn.PReLU"""
code = """ code = "nn.PReLU()(input)"
input = torch.randn(2, 3, 5)
nn.PReLU()(input)
"""
api_name = 'nn.PReLU'
start = code.find(api_name)
expected_ms_api_name = 'nn.PReLU' expected_ms_api_name = 'nn.PReLU'
replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('nn.PReLU()(input)', assert replaced_code == code.replace('nn.PReLU()(input)',
'{}()(input)'.format(expected_ms_api_name)) '{}()(input)'.format(expected_ms_api_name))
assert new_start == start + len(expected_ms_api_name)
def test_convert_api_nn_softmax(self): def test_convert_api_nn_softmax(self):
"""Test convert_api function work ok when convert api nn.Softmax""" """Test convert_api function work ok when convert api nn.Softmax"""
code = """ code = "nn.Softmax(dim=1)"
nn.Softmax(dim=1)(input)
"""
api_name = 'nn.Softmax'
expected_ms_api_name = 'nn.Softmax' expected_ms_api_name = 'nn.Softmax'
start = code.find(api_name)
replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('nn.Softmax(dim=1)(input)', assert replaced_code == code.replace('nn.Softmax(dim=1)',
'{}(axis=1)(input)'.format(expected_ms_api_name)) '{}(axis=1)'.format(expected_ms_api_name))
assert new_start == start + len(expected_ms_api_name)
# test convert_api with torch dot ops # test convert_api with torch dot ops
def test_convert_api_torch_dot_abs(self): def test_convert_api_torch_dot_abs(self):
"""Test convert_api function work ok when convert api torch.abs""" """Test convert_api function work ok when convert api torch.abs"""
code = """ code = "torch.abs(input)"
torch.abs(input)
"""
api_name = 'torch.abs'
start = code.find(api_name)
expected_ms_api_name = 'P.Abs' expected_ms_api_name = 'P.Abs'
replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('torch.abs(input)', assert replaced_code == code.replace('torch.abs(input)',
'{}()(input)'.format(expected_ms_api_name)) '{}()(input)'.format(expected_ms_api_name))
assert new_start == start + len(expected_ms_api_name)
def test_convert_api_torch_dot_acos(self): def test_convert_api_torch_dot_acos(self):
"""Test convert_api function work ok when convert api torch.acos""" """Test convert_api function work ok when convert api torch.acos"""
code = """ code = "torch.acos(input)"
torch.acos(input)
"""
api_name = 'torch.acos'
start = code.find(api_name)
expected_ms_api_name = 'P.ACos' expected_ms_api_name = 'P.ACos'
replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('torch.acos(input)', assert replaced_code == code.replace('torch.acos(input)',
'{}()(input)'.format(expected_ms_api_name)) '{}()(input)'.format(expected_ms_api_name))
assert new_start == start + len(expected_ms_api_name)
def test_convert_api_torch_dot_cos(self): def test_convert_api_torch_dot_cos(self):
"""Test convert_api function work ok when convert api torch.cos""" """Test convert_api function work ok when convert api torch.cos"""
code = """ code = "torch.cos(input)"
torch.cos(input)
"""
api_name = 'torch.cos'
expected_ms_api_name = 'P.Cos' expected_ms_api_name = 'P.Cos'
start = code.find(api_name)
replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('torch.cos(input)', assert replaced_code == code.replace('torch.cos(input)',
'{}()(input)'.format(expected_ms_api_name)) '{}()(input)'.format(expected_ms_api_name))
assert new_start == start + len(expected_ms_api_name)
def test_convert_api_torch_dot_exp(self): def test_convert_api_torch_dot_exp(self):
"""Test convert_api function work ok when convert api torch.exp""" """Test convert_api function work ok when convert api torch.exp"""
code = """ code = "torch.exp(input)"
torch.exp(input)
"""
api_name = 'torch.exp'
expected_ms_api_name = 'P.Exp' expected_ms_api_name = 'P.Exp'
start = code.find(api_name)
replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('torch.exp(input)', assert replaced_code == code.replace('torch.exp(input)',
'{}()(input)'.format(expected_ms_api_name)) '{}()(input)'.format(expected_ms_api_name))
assert new_start == start + len(expected_ms_api_name)
def test_convert_api_torch_dot_log(self): def test_convert_api_torch_dot_log(self):
"""Test convert_api function work ok when convert api torch.log""" """Test convert_api function work ok when convert api torch.log"""
code = """ code = "torch.log(input)"
torch.log(input)
"""
api_name = 'torch.log'
expected_ms_api_name = 'P.Log' expected_ms_api_name = 'P.Log'
start = code.find(api_name)
replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('torch.log(input)', assert replaced_code == code.replace('torch.log(input)',
'{}()(input)'.format(expected_ms_api_name)) '{}()(input)'.format(expected_ms_api_name))
assert new_start == start + len(expected_ms_api_name)
def test_convert_api_torch_dot_pow(self): def test_convert_api_torch_dot_pow(self):
"""Test convert_api function work ok when convert api torch.pow""" """Test convert_api function work ok when convert api torch.pow"""
code = """ code = "torch.pow(a, exp)"
torch.pow(a, exp)
"""
api_name = 'torch.pow'
expected_ms_api_name = 'P.Pow' expected_ms_api_name = 'P.Pow'
start = code.find(api_name)
replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('torch.pow(a, exp)', assert replaced_code == code.replace('torch.pow(a, exp)',
'{}()(a, exp)'.format(expected_ms_api_name)) '{}()(a, exp)'.format(expected_ms_api_name))
assert new_start == start + len(expected_ms_api_name)
def test_convert_api_torch_dot_div(self): def test_convert_api_torch_dot_div(self):
"""Test convert_api function work ok when convert api torch.div""" """Test convert_api function work ok when convert api torch.div"""
code = """ code = "torch.div(input, other)"
input = torch.randn(5)
other = torch.randn(5)
torch.div(input, other)
"""
api_name = 'torch.div'
expected_ms_api_name = 'P.Div' expected_ms_api_name = 'P.Div'
start = code.find(api_name)
replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('torch.div(input, other)', assert replaced_code == code.replace('torch.div(input, other)',
'{}()(input, other)'.format(expected_ms_api_name)) '{}()(input, other)'.format(expected_ms_api_name))
assert new_start == start + len(expected_ms_api_name)
def test_convert_api_torch_dot_sin(self): def test_convert_api_torch_dot_sin(self):
"""Test convert_api function work ok when convert api torch.sin""" """Test convert_api function work ok when convert api torch.sin"""
code = """ code = "torch.sin(input)"
torch.sin(input)
"""
api_name = 'torch.sin'
expected_ms_api_name = 'P.Sin' expected_ms_api_name = 'P.Sin'
start = code.find(api_name)
replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('torch.sin(input)', assert replaced_code == code.replace('torch.sin(input)',
'{}()(input)'.format(expected_ms_api_name)) '{}()(input)'.format(expected_ms_api_name))
assert new_start == start + len(expected_ms_api_name)
def test_convert_api_torch_dot_sqrt(self): def test_convert_api_torch_dot_sqrt(self):
"""Test convert_api function work ok when convert api torch.sqrt""" """Test convert_api function work ok when convert api torch.sqrt"""
code = """ code = "torch.sqrt(input)"
torch.sqrt(input)
"""
api_name = 'torch.sqrt'
expected_ms_api_name = 'P.Sqrt' expected_ms_api_name = 'P.Sqrt'
start = code.find(api_name)
replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('torch.sqrt(input)', assert replaced_code == code.replace('torch.sqrt(input)',
'{}()(input)'.format(expected_ms_api_name)) '{}()(input)'.format(expected_ms_api_name))
assert new_start == start + len(expected_ms_api_name)
def test_convert_api_torch_dot_eye_with_n(self): def test_convert_api_torch_dot_eye_with_n(self):
"""Test convert_api function work ok when convert api torch.eye""" """Test convert_api function work ok when convert api torch.eye"""
code = """ code = "torch.eye(3)"
torch.eye(3)
"""
api_name = 'torch.eye'
expected_ms_api_name = 'P.Eye' expected_ms_api_name = 'P.Eye'
start = code.find(api_name)
replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('torch.eye(3)', assert replaced_code == code.replace('torch.eye(3)',
'{}()(3, 3, mindspore.int32)'.format(expected_ms_api_name)) '{}()(3, 3, mindspore.int32)'.format(expected_ms_api_name))
assert new_start == start + len(expected_ms_api_name)
def test_convert_api_torch_dot_eye_with_m(self): def test_convert_api_torch_dot_eye_with_m(self):
"""Test convert_api function work ok when convert api torch.eye""" """Test convert_api function work ok when convert api torch.eye"""
code = """ code = "torch.eye(3, 4)"
torch.eye(3, 4)
"""
api_name = 'torch.eye'
expected_ms_api_name = 'P.Eye' expected_ms_api_name = 'P.Eye'
start = code.find(api_name)
replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('torch.eye(3, 4)', assert replaced_code == code.replace('torch.eye(3, 4)',
'{}()(3, 4, mindspore.int32)'.format(expected_ms_api_name)) '{}()(3, 4, mindspore.int32)'.format(expected_ms_api_name))
assert new_start == start + len(expected_ms_api_name)
def test_convert_api_torch_dot_add_with_alpha_default(self): def test_convert_api_torch_dot_add_with_alpha_default(self):
"""Test convert_api function work ok when convert api torch.add""" """Test convert_api function work ok when convert api torch.add"""
code = """ code = "torch.add(input, value)"
torch.add(input, value)
"""
api_name = 'torch.add'
expected_ms_api_name = 'P.TensorAdd' expected_ms_api_name = 'P.TensorAdd'
start = code.find(api_name)
replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('torch.add(input, value)', assert replaced_code == code.replace('torch.add(input, value)',
'{}()(input, value)'.format(expected_ms_api_name)) '{}()(input, value)'.format(expected_ms_api_name))
assert new_start == start + len(expected_ms_api_name)
def test_convert_api_torch_dot_add_with_alpha_not_default(self): def test_convert_api_torch_dot_add_with_alpha_not_default(self):
"""Test convert_api function work ok when convert api torch.add""" """Test convert_api function work ok when convert api torch.add"""
code = """ code = "torch.add(input, value, 3)"
torch.add(input, value, 3)
"""
api_name = 'torch.add'
expected_ms_api_name = 'P.TensorAdd' expected_ms_api_name = 'P.TensorAdd'
start = code.find(api_name)
replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('torch.add(input, value, 3)', assert replaced_code == code.replace('torch.add(input, value, 3)',
'{}()(input, value*3)'.format(expected_ms_api_name)) '{}()(input, value*3)'.format(expected_ms_api_name))
assert new_start == start + len(expected_ms_api_name)
# test convert_api with F ops # test convert_api with F ops
def test_convert_api_f_normalize(self): def test_convert_api_f_normalize(self):
"""Test convert_api function work ok when convert api F.normalize""" """Test convert_api function work ok when convert api F.normalize"""
code = """ code = "F.normalize(input)"
input = torch.randn(2, 3, 5)
F.normalize(input)
"""
api_name = 'F.normalize'
start = code.find(api_name)
expected_ms_api_name = 'P.L2Normalize' expected_ms_api_name = 'P.L2Normalize'
replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('F.normalize(input)', assert replaced_code == code.replace('F.normalize(input)',
'{}(1, 1e-12)(input)'.format(expected_ms_api_name)) '{}(1, 1e-12)(input)'.format(expected_ms_api_name))
assert new_start == start + len(expected_ms_api_name)
def test_convert_api_f_sigmoid(self): def test_convert_api_f_sigmoid(self):
"""Test convert_api function work ok when convert api F.sigmoid""" """Test convert_api function work ok when convert api F.sigmoid"""
code = """ code = "F.sigmoid(input)"
input = torch.randn(2, 3, 5)
F.sigmoid(input)
"""
api_name = 'F.sigmoid'
start = code.find(api_name)
expected_ms_api_name = 'P.Sigmoid' expected_ms_api_name = 'P.Sigmoid'
replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('F.sigmoid(input)', assert replaced_code == code.replace('F.sigmoid(input)',
'{}()(input)'.format(expected_ms_api_name)) '{}()(input)'.format(expected_ms_api_name))
assert new_start == start + len(expected_ms_api_name)
# test convert_api with tensor dot ops # test convert_api with tensor dot ops
def test_convert_api_tensor_dot_repeat(self): def test_convert_api_tensor_dot_repeat(self):
"""Test convert_api function work ok when convert api .repeat""" """Test convert_api function work ok when convert api .repeat"""
code = """ code = "x.repeat(4, 2)"
x.repeat(4, 2)
"""
api_name = '.repeat'
start = code.find(api_name)
expected_ms_api_name = 'P.Tile' expected_ms_api_name = 'P.Tile'
replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('x.repeat(4, 2)', assert replaced_code == code.replace('x.repeat(4, 2)',
'{}()(x, {})'.format(expected_ms_api_name, '(4, 2,)')) '{}()(x, {})'.format(expected_ms_api_name, '(4, 2,)'))
assert new_start == start + len(expected_ms_api_name)
def test_convert_api_tensor_dot_permute(self): def test_convert_api_tensor_dot_permute(self):
"""Test convert_api function work ok when convert api .permute""" """Test convert_api function work ok when convert api .permute"""
code = """ code = "x.permute(2, 0, 1)"
x.permute(2, 0, 1)
"""
api_name = '.permute'
start = code.find(api_name)
expected_ms_api_name = 'P.Transpose' expected_ms_api_name = 'P.Transpose'
replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('x.permute(2, 0, 1)', assert replaced_code == code.replace('x.permute(2, 0, 1)',
'{}()(x, (2, 0, 1,))'.format(expected_ms_api_name)) '{}()(x, (2, 0, 1,))'.format(expected_ms_api_name))
assert new_start == start + len(expected_ms_api_name)
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
"""Test forward_call module.""" """Test forward_call module."""
import ast import ast
import textwrap import textwrap
from unittest.mock import patch
from mindinsight.mindconverter.forward_call import ForwardCall from mindinsight.mindconverter.forward_call import ForwardCall
...@@ -50,12 +49,10 @@ class TestForwardCall: ...@@ -50,12 +49,10 @@ class TestForwardCall:
return out return out
""") """)
@patch.object(ForwardCall, 'process') def test_process(self):
def test_process(self, mock_process):
"""Test the function of visit ast tree to find out forward functions.""" """Test the function of visit ast tree to find out forward functions."""
mock_process.return_value = None ast_tree = ast.parse(self.source)
forward_call = ForwardCall("mock") forward_call = ForwardCall(ast_tree)
forward_call.visit(ast.parse(self.source))
expect_calls = ['TestNet.forward', expect_calls = ['TestNet.forward',
'TestNet.forward1', 'TestNet.forward1',
...@@ -70,6 +67,6 @@ class TestForwardCall: ...@@ -70,6 +67,6 @@ class TestForwardCall:
'TestNet.fc3', 'TestNet.fc3',
] ]
expect_calls.sort() expect_calls.sort()
real_calls = list(forward_call.calls) real_calls = list(forward_call.calls.keys())
real_calls.sort() real_calls.sort()
assert real_calls == expect_calls assert real_calls == expect_calls
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册