提交 5b288467 编写于 作者: Q quyongxiu1

add cli and report

fix review pros

add readme
上级 98516a84
### Introduction
MindConverter is a tool that converting PyTorch scripts to MindSpore scripts. With minial manual editing and the guidance from conversion reports, users may easily migrate their model from PyTorch framework to MindSpore.
### System Requirements
* PyTorch v1.5.0
* MindSpore v0.2.0
### Installation
This tool is part of MindInsight and accessible to users after installing MindInsight, no extra installation is needed.
### Commandline Usage
Set the model scripts directory as the PYTHONPATH environment variable first:
```buildoutcfg
export PYTHONPATH=<model scripts dir>
```
mindconverter commandline usage:
```buildoutcfg
mindconverter [-h] [--version] --in_file IN_FILE [--output OUTPUT]
[--report REPORT]
MindConverter CLI entry point (version: 0.2.0)
optional arguments:
-h, --help show this help message and exit
--version show program's version number and exit
--in_file IN_FILE Specify path for script file.
--output OUTPUT Specify path for converted script file directory. Default
is output directory in the current working directory.
--report REPORT Specify report directory. Default is the current working
directorys
```
Usage example:
```buildoutcfg
export PYTHONPATH=~/my_pt_proj/models
mindconverter --in_file lenet.py
```
Since the conversion is not 100% flawless, we encourage users to checkout the reports when fixing issues of the converted scripts.
### Unsupported Situation #1
Classes and functions that can't be converted:
* The use of shape, ndim and dtype member of torch.Tensor.
* torch.nn.AdaptiveXXXPoolXd and torch.nn.functional.adaptive_XXX_poolXd()
* torch.nn.functional.Dropout
* torch.unsqueeze() and torch.Tensor.unsqueeze()
* torch.chunk() and torch.Tensor.chunk()
### Unsupported Situation #2
Subclassing from the subclasses of nn.Module
e.g. (code snip from torchvision,models.mobilenet)
```python
from torch import nn
class ConvBNReLU(nn.Sequential):
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
padding = (kernel_size - 1) // 2
super(ConvBNReLU, self).__init__(
nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
nn.BatchNorm2d(out_planes),
nn.ReLU6(inplace=True)
)
```
### Unsupported Situation #3
Unconventional import naming
e.g.
```python
import torch.nn as mm
```
\ No newline at end of file
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Command module."""
import os
import sys
import argparse
import mindinsight
from mindinsight.mindconverter.converter import main
class FileDirAction(argparse.Action):
"""File directory action class definition."""
@staticmethod
def check_path(parser, values, option_string=None):
"""
Check argument for file path.
Args:
parser (ArgumentParser): Passed-in argument parser.
values (object): Argument values with type depending on argument definition.
option_string (str): Optional string for specific argument name. Default: None.
"""
outfile = values
if outfile.startswith('~'):
outfile = os.path.realpath(os.path.expanduser(outfile))
if not outfile.startswith('/'):
outfile = os.path.realpath(os.path.join(os.getcwd(), outfile))
if os.path.exists(outfile) and not os.access(outfile, os.R_OK):
parser.error(f'{option_string} {outfile} not accessible')
return outfile
def __call__(self, parser, namespace, values, option_string=None):
"""
Inherited __call__ method from argparse.Action.
Args:
parser (ArgumentParser): Passed-in argument parser.
namespace (Namespace): Namespace object to hold arguments.
values (object): Argument values with type depending on argument definition.
option_string (str): Optional string for specific argument name. Default: None.
"""
outfile_dir = self.check_path(parser, values, option_string)
if os.path.isfile(outfile_dir):
parser.error(f'{option_string} {outfile_dir} is a file')
setattr(namespace, self.dest, outfile_dir)
class OutputDirAction(argparse.Action):
"""File directory action class definition."""
def __call__(self, parser, namespace, values, option_string=None):
"""
Inherited __call__ method from argparse.Action.
Args:
parser (ArgumentParser): Passed-in argument parser.
namespace (Namespace): Namespace object to hold arguments.
values (object): Argument values with type depending on argument definition.
option_string (str): Optional string for specific argument name. Default: None.
"""
output = values
if output.startswith('~'):
output = os.path.realpath(os.path.expanduser(output))
if not output.startswith('/'):
output = os.path.realpath(os.path.join(os.getcwd(), output))
if os.path.exists(output):
if not os.access(output, os.R_OK):
parser.error(f'{option_string} {output} not accessible')
if os.path.isfile(output):
parser.error(f'{option_string} {output} is a file')
setattr(namespace, self.dest, output)
class InFileAction(argparse.Action):
"""Input File action class definition."""
def __call__(self, parser, namespace, values, option_string=None):
"""
Inherited __call__ method from argparse.Action.
Args:
parser (ArgumentParser): Passed-in argument parser.
namespace (Namespace): Namespace object to hold arguments.
values (object): Argument values with type depending on argument definition.
option_string (str): Optional string for specific argument name. Default: None.
"""
outfile_dir = FileDirAction.check_path(parser, values, option_string)
if not os.path.exists(outfile_dir):
parser.error(f'{option_string} {outfile_dir} not exists')
if not os.path.isfile(outfile_dir):
parser.error(f'{option_string} {outfile_dir} is not a file')
setattr(namespace, self.dest, outfile_dir)
class LogFileAction(argparse.Action):
"""Log file action class definition."""
def __call__(self, parser, namespace, values, option_string=None):
"""
Inherited __call__ method from FileDirAction.
Args:
parser (ArgumentParser): Passed-in argument parser.
namespace (Namespace): Namespace object to hold arguments.
values (object): Argument values with type depending on argument definition.
option_string (str): Optional string for specific argument name. Default: None.
"""
outfile_dir = FileDirAction.check_path(parser, values, option_string)
if os.path.exists(outfile_dir) and not os.path.isdir(outfile_dir):
parser.error(f'{option_string} {outfile_dir} is not a directory')
setattr(namespace, self.dest, outfile_dir)
def cli_entry():
"""Entry point for mindconverter CLI."""
permissions = os.R_OK | os.W_OK | os.X_OK
os.umask(permissions << 3 | permissions)
parser = argparse.ArgumentParser(
prog='mindconverter',
description='MindConverter CLI entry point (version: {})'.format(mindinsight.__version__))
parser.add_argument(
'--version',
action='version',
version='%(prog)s ({})'.format(mindinsight.__version__))
parser.add_argument(
'--in_file',
type=str,
action=InFileAction,
required=True,
help="""
Specify path for script file.
""")
parser.add_argument(
'--output',
type=str,
action=OutputDirAction,
default=os.path.join(os.getcwd(), 'output'),
help="""
Specify path for converted script file directory.
Default is output directory in the current working directory.
""")
parser.add_argument(
'--report',
type=str,
action=LogFileAction,
default=os.getcwd(),
help="""
Specify report directory. Default is the current working directory.
""")
argv = sys.argv[1:]
if not argv:
argv = ['-h']
args = parser.parse_args(argv)
else:
args = parser.parse_args()
mode = permissions << 6
os.makedirs(args.output, mode=mode, exist_ok=True)
os.makedirs(args.report, mode=mode, exist_ok=True)
_run(args.in_file, args.output, '', args.report)
def _run(in_files, out_dir, in_module, report):
"""
Run converter command.
Args:
in_files (str): The file path or directory to convert.
out_dir (str): The output directory to save converted file.
in_module (str): The module name to convert.
report (str): The report file path.
"""
files_config = {
'root_path': in_files if in_files else '',
'in_files': [],
'outfile_dir': out_dir,
'report_dir': report,
'in_module': in_module
}
if os.path.isfile(in_files):
files_config['root_path'] = os.path.dirname(in_files)
files_config['in_files'] = [in_files]
else:
for root_dir, _, files in os.walk(in_files):
for file in files:
files_config['in_files'].append(os.path.join(root_dir, file))
main(files_config)
...@@ -15,4 +15,4 @@ ...@@ -15,4 +15,4 @@
"""Create a logger.""" """Create a logger."""
from mindinsight.utils.log import setup_logger from mindinsight.utils.log import setup_logger
logger = setup_logger("mindconverter", "mindconverter") logger = setup_logger("mindconverter", "mindconverter", console=False)
...@@ -67,7 +67,7 @@ class APIPt: ...@@ -67,7 +67,7 @@ class APIPt:
Raises: Raises:
ValueError: If can not use ast to parse or the required parse node not type of ast.Call, ValueError: If can not use ast to parse or the required parse node not type of ast.Call,
or the given args_str not valid. or the given args_str not valid.
""" """
# expr is REQUIRED to meet (**) format # expr is REQUIRED to meet (**) format
if not (len(args_str) >= 2 and args_str[0] == "(" and args_str[-1] == ")"): if not (len(args_str) >= 2 and args_str[0] == "(" and args_str[-1] == ")"):
...@@ -222,6 +222,7 @@ class MappingHelper: ...@@ -222,6 +222,7 @@ class MappingHelper:
def convert(self, call_name_pt: str, args_str_pt: str): def convert(self, call_name_pt: str, args_str_pt: str):
""" """
Convert code sentence to MindSpore code sentence. Convert code sentence to MindSpore code sentence.
Args: Args:
call_name_pt (str): str of the call function, etc. call_name_pt (str): str of the call function, etc.
args_str_pt (str): str of args for function, which starts with '(' and end with ')'. args_str_pt (str): str of args for function, which starts with '(' and end with ')'.
...@@ -336,11 +337,12 @@ def gen_explicit_map_one_delta(params_pt, args_pt, k_ms, k_pt): ...@@ -336,11 +337,12 @@ def gen_explicit_map_one_delta(params_pt, args_pt, k_ms, k_pt):
def load_json_file(file_path): def load_json_file(file_path):
""" """
Load data from given json file path. Load data from given json file path.
Args: Args:
file_path (str): The file to load json data from. file_path (str): The file to load json data from.
Returns: Returns:
list, the list data stored in file_path. list(str), the list data stored in file_path.
""" """
with open(file_path, 'r', encoding='utf-8') as file: with open(file_path, 'r', encoding='utf-8') as file:
info = json.loads(file.read()) info = json.loads(file.read())
......
...@@ -12,10 +12,10 @@ ...@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""main module""" """converter module"""
import inspect
import copy import copy
import importlib import importlib
import inspect
import os import os
import stat import stat
...@@ -24,332 +24,438 @@ from mindinsight.mindconverter.config import NN_LIST ...@@ -24,332 +24,438 @@ from mindinsight.mindconverter.config import NN_LIST
from mindinsight.mindconverter.config import ALL_TORCH_APIS from mindinsight.mindconverter.config import ALL_TORCH_APIS
from mindinsight.mindconverter.config import ALL_2P_LIST from mindinsight.mindconverter.config import ALL_2P_LIST
from mindinsight.mindconverter.config import UNSUPPORTED_WARN_INFOS from mindinsight.mindconverter.config import UNSUPPORTED_WARN_INFOS
from mindinsight.mindconverter.config import ALL_UNSUPPORTED
from mindinsight.mindconverter.common.log import logger from mindinsight.mindconverter.common.log import logger
from mindinsight.mindconverter.forward_call import ForwardCall
def is_local_defined(obj, member): class Converter:
""" """Convert class"""
Check if obj and member are both defined in the same source file.
Args:
obj (Union[object, module]): A module or a class.
member (func): A function of obj.
Returns:
bool, True or False.
"""
srcfile = inspect.getsourcefile(obj)
return inspect.getsourcefile(member) == srcfile
def is_valid_module(obj, member):
"""
Check if obj and member defined in same source file and member is inherited from torch.nn.Module.
Args:
obj (Union[object, module]): A module or a class.
member (func): A function.
Returns:
bool, True or False.
"""
return inspect.isclass(member) and (member.__base__.__name__ == 'Module') and is_local_defined(obj, member)
def is_valid_function(obj, member):
"""
Check if member is function and defined in the file same as obj.
Args:
obj (Union[object, module]: The obj.
member (func): The func.
Returns:
bool, True or False.
"""
return inspect.isfunction(member) and is_local_defined(obj, member)
def find_left_parentheses(string, right):
"""
Find index of the first left parenthesis.
Args:
string (str): A line of code.
right (int): Max index of string, same as `len(string) -1`.
Returns:
int, index of the first parenthesis.
Raises:
ValueError: If line of code doesn't contain any pair of `()` or `(` and `)` are not paired.
"""
if string[right] != ')':
raise ValueError('code [{}] at index {} not ")".'.format(string, right))
stack = []
for i in range(right, -1, -1):
if string[i] == ')':
stack.append(')')
elif string[i] == '(':
stack.pop()
if not stack:
return i
raise ValueError("{} should contain ()".format(string))
def find_right_parentheses(string, left):
"""
Find first index of right parenthesis which make all left parenthesis make sense.
Args:
string (str): A line of code.
left (int): Start index of string to find from.
Returns:
int, index of the found right parenthesis.
Raises:
ValueError: If line of code doesn't contain any pair of `()` or `(` and `)` are not paired.
"""
stack = []
for i in range(left, len(string)):
if string[i] == '(':
stack.append('(')
elif string[i] == ')':
stack.pop()
if not stack:
return i
raise ValueError("{} should contain ()".format(string))
def get_call_name(code, end):
"""
Traverse code in a reversed function from index end and get the call name and start index of the call name, if call
name not found, return a null character string and -1
Args: convert_info = ''
code (str): The str of code to find from. flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL
end (int): Start index to find. modes = stat.S_IWUSR | stat.S_IRUSR
Returns: @staticmethod
str, founded api name if found, else a null character string. def is_local_defined(obj, member):
int, start index of founded api name, -1 if api name not found """
""" Check if obj and member are both defined in the same source file.
stack = []
for i in range(end - 1, -1, -1): Args:
if code[i] in ["(", "[", "{"]: obj (Union[object, module]): A module or a class.
if stack: member (func): A function of obj.
Returns:
bool, True or False.
"""
srcfile = inspect.getsourcefile(obj)
return inspect.getsourcefile(member) == srcfile
@classmethod
def is_valid_module(cls, obj, member):
"""
Check if obj and member defined in same source file and member is inherited from torch.nn.Module.
Args:
obj (Union[object, module]): A module or a class.
member (func): A function.
Returns:
bool, True or False.
"""
if inspect.isclass(member):
is_subclass = member.__base__.__name__ in ['Module',
'Sequential',
'ModuleList',
'ModuleDict',
'ParameterList',
'ParameterDict']
return is_subclass and cls.is_local_defined(obj, member)
return False
@classmethod
def is_valid_function(cls, obj, member):
"""
Check if member is function and defined in the file same as obj.
Args:
obj (Union[object, module]: The obj.
member (func): The func.
Returns:
bool, True or False.
"""
return inspect.isfunction(member) and cls.is_local_defined(obj, member)
@staticmethod
def find_left_parentheses(string, right):
"""
Find index of the first left parenthesis.
Args:
string (str): A line of code.
right (int): The right index for string to find from.
Returns:
int, index of the first parenthesis.
Raises:
ValueError: If line of code doesn't contain any pair of `()` or `(` and `)` are not paired.
"""
if string[right] != ')':
raise ValueError('code [{}] at index {} not ")".'.format(string, right))
stack = []
for i in range(right, -1, -1):
if string[i] == ')':
stack.append(')')
elif string[i] == '(':
stack.pop()
if not stack:
return i
raise ValueError("{} should contain ()".format(string))
@staticmethod
def find_right_parentheses(string, left):
"""
Find first index of right parenthesis which make all left parenthesis make sense.
Args:
string (str): A line of code.
left (int): Start index of string to find from.
Returns:
int, index of the found right parenthesis.
Raises:
ValueError: If line of code doesn't contain any pair of `()` or `(` and `)` are not paired.
"""
stack = []
for i in range(left, len(string)):
if string[i] == '(':
stack.append('(')
elif string[i] == ')':
stack.pop() stack.pop()
else: if not stack:
return i
raise ValueError("{} should contain ()".format(string))
@staticmethod
def get_call_name(code, end):
"""
Traverse code in a reversed function from index end and get the call name and start index of the call name,
if call name not found, return a null character string and -1
Args:
code (str): The str of code to find from.
end (int): Start index to find.
Returns:
tuple(str, int), one is founded api name if found, else a null character string, the other is start index
of founded api name, -1 if api name not found
"""
stack = []
for i in range(end - 1, -1, -1):
if code[i] in ["(", "[", "{"]:
if stack:
stack.pop()
else:
return code[i + 1:end], i + 1
elif code[i] in [")", "]", "}"]:
stack.append(code[i])
elif stack:
continue
elif not (code[i].isalpha() or code[i].isdigit() or code[i] == '_' or code[i] == '.'):
return code[i + 1:end], i + 1 return code[i + 1:end], i + 1
elif code[i] in [")", "]", "}"]: return "", -1
stack.append(code[i])
elif stack: def convert_api(self, code, start, api_name=""):
continue """
elif not (code[i].isalpha() or code[i].isdigit() or code[i] == '_' or code[i] == '.'): Convert api_name in code to MindSpore api with start as a start index, if api_name is a python api,
return code[i + 1:end], i + 1 code will not convert.
return "", -1
Args:
code (str): The str code to convert.
def convert_api(code, start, api_name=""): start (int): The index of code to start convert from.
""" api_name (str): The api name to convert.
Convert api_name in code to MindSpore api with start as a start index, if api_name is a python api, code will not
convert. Returns:
str, the converted code.
Args: int, index of converted api_name in code.
code (str): The str code to convert. """
start (int): The index of code to start convert from. # handle format like .shape(
api_name (str): The api name to convert. if api_name.startswith('.'):
call_name, new_start = self.get_call_name(code, start)
Returns: if start == -1 or call_name == "self":
str, the converted code. return code, start + 1
int, index of converted api_name in code. else:
call_name = api_name
""" new_start = start
# handle format like .shape(
if api_name.startswith('.'): # find full api expected to be converted. eg:expr="nn.Conv2d(1,2,3)" args_str="(1,2,3)"
call_name, new_start = get_call_name(code, start) left = code.find("(", start)
if start == -1 or call_name == "self": if left == -1:
return code, start + 1 raise ValueError('"(" not found, {} should work with "("'.format(call_name))
else: right = self.find_right_parentheses(code, left)
call_name = api_name end = right
new_start = start expr = code[start:end + 1]
args_str = code[left:right + 1]
# find full api expected to be converted. eg:expr="nn.Conv2d(1,2,3)" args_str="(1,2,3)"
left = code.find("(", start) map_helper = ALL_MAPPING[api_name]
if left == -1: new_expr = map_helper.convert(call_name, args_str)
raise ValueError('"(" not found, {} should work with "("'.format(call_name)) next_newline = code.find("\n", end + 1)
right = find_right_parentheses(code, left) fill_num = (expr.count("\n") - new_expr.count("\n"))
end = right if next_newline != -1:
expr = code[start:end + 1] code = code[:new_start] + new_expr + code[end + 1:next_newline] + ("\n" * fill_num) + code[next_newline:]
args_str = code[left:right + 1] else:
code = code[:new_start] + new_expr + ")" + ("\n" * fill_num) + code[end + 2:]
map_helper = ALL_MAPPING[api_name]
new_expr = map_helper.convert(call_name, args_str) return code, start + len(map_helper.ms_api.name)
next_newline = code.find("\n", end + 1)
fill_num = (expr.count("\n") - new_expr.count("\n")) @staticmethod
if next_newline != -1: def find_api(code, i, is_forward):
code = code[:new_start] + new_expr + code[end + 1:next_newline] + ("\n" * fill_num) + code[next_newline:] """
else: Find api name from code with a start index i, check api name ok with a is_forward condition.
code = code[:new_start] + new_expr + ")" + ("\n" * fill_num) + code[end + 2:]
return code, start + len(map_helper.ms_api.name) Args:
code (str): The code from which to find api name.
i (int): The start index to find.
def find_api(code, i, is_forward): is_forward (bool): Check if the found api name ok.
Returns:
str, api name if find api name and check ok with is_forward condition, else a null character string.
"""
if code[i:].startswith("nn.") \
or code[i:].startswith("F.") \
or code[i:].startswith("torch.") \
or code[i:].startswith('.'):
j = code.find('(', i)
if j != -1 and code[i:j] in ALL_TORCH_APIS:
api_name = code[i:j]
if (not is_forward and api_name in NN_LIST) or (is_forward and api_name in ALL_2P_LIST):
return api_name
return ""
def convert_function(self, fun_name, fun, is_forward):
"""
Convert a PyTorch function into MindSpore function.
Args:
fun_name (str): The str of function name.
fun (func): The function to convert.
is_forward (bool): If the function is defined in forward function in nn.Module in torch.
Returns:
dict, old code and converted code map if convert happens, else {}.
"""
_, line_no = inspect.getsourcelines(fun)
logger.info("Line %3d: start converting function %s()", line_no, fun_name)
code = inspect.getsource(fun)
code_saved = copy.copy(code)
i = 0
while i < len(code):
api_name = self.find_api(code, i, is_forward)
if api_name:
line_no1 = line_no + code[:i].count('\n')
if api_name in ALL_MAPPING:
logger.info("Line %3d start converting API: %s", line_no1, api_name)
code, i = self.convert_api(code, i, api_name)
self.convert_info += "[Convert][Line{:3d}] {} is converted.\n".format(line_no1, api_name)
continue
if api_name in ALL_UNSUPPORTED:
warn_info = ". " + UNSUPPORTED_WARN_INFOS[api_name] if api_name in UNSUPPORTED_WARN_INFOS else ""
logger.warning("Line %3d: found unsupported API: %s%s", line_no1, api_name, warn_info)
self.convert_info += "[Unconvert][Line{:3d}] {} didn't convert{}\n".format(line_no1,
api_name, warn_info)
i += 1
return {code_saved: code} if code_saved != code else {}
@staticmethod
def judge_forward(name, forward_list):
"""
Check if function is a forward function.
Args:
name (str): The function name.
forward_list (set): A set of forward function.
Returns:
bool, True or False
"""
is_forward = name in forward_list or name.split(".")[-1] == "forward"
if is_forward:
logger.debug("%s is a forward function", name)
return is_forward
def convert_module(self, module_name, module, forward_list):
"""
Convert a PyTorch module code into MindSpore module code.
Args:
module_name (str): The module's name.
module (module): The module to convert.
forward_list (set): A set of forward function.
Returns:
dict, map of old code and converted code.
"""
_, line_no = inspect.getsourcelines(module)
logger.info("Line {:3d}: start converting nn.Module {}".format(line_no, module_name))
mapped = {}
for name, member in inspect.getmembers(module):
if self.is_valid_function(module, member):
is_forward = self.judge_forward("{}.{}".format(module_name, name), forward_list)
mapped.update(self.convert_function(name, member, is_forward))
return mapped
def get_mapping(self, import_mod, forward_list):
"""
Convert code of a module and get mapping of old code and convert code.
Args:
import_mod (module): The module to convert.
forward_list (set): A set of forward function.
Returns:
dict, mapping for old code and converted code of the module
"""
mapping = {}
tasks = []
for name, member in inspect.getmembers(import_mod):
if self.is_valid_module(import_mod, member):
_, line_no = inspect.getsourcelines(member)
tasks.append((line_no, self.convert_module, (name, member, forward_list)))
elif self.is_valid_function(import_mod, member):
_, line_no = inspect.getsourcelines(member)
is_forward = self.judge_forward("{}.{}".format(import_mod, name), forward_list)
tasks.append((line_no, self.convert_function, (name, member, is_forward)))
tasks.sort()
for _, convert_fun, args in tasks:
mapping.update(convert_fun(*args))
return mapping
def convert(self, import_name, output_dir, report_dir):
"""
Convert a module's code, code converted will be save in output_dir, and a report will be save in report_dir.
Args:
import_name (str): The module from which to import the module to convert.
output_dir (str): The path to save converted file.
report_dir (str): The path to save report file.
"""
logger.info("Start converting %s", import_name)
self.convert_info += '[Start Convert]\nThe module is {}\n'.format(import_name)
import_mod = importlib.import_module(import_name)
srcfile = inspect.getsourcefile(import_mod)
logger.info("Script file is %s", srcfile)
forward_list = set(ForwardCall(srcfile).calls)
logger.debug("Forward_list: %s", forward_list)
# replace python function under nn.Module
mapping = self.get_mapping(import_mod, forward_list)
code = inspect.getsource(import_mod)
for key, value in mapping.items():
code = code.replace(key, value)
code = 'import mindspore.ops.operations as P\n' + code
code = 'import mindspore.nn as nn\n' + code
code = 'import mindspore\n' + code
self.convert_info += '||[Import Add] Add follow import sentences:\n'
self.convert_info += 'import mindspore.ops.operations as P\n'
self.convert_info += 'import mindspore.nn as nn\n'
self.convert_info += 'import mindspore\n\n'
code = code.replace('import torch', '# import torch')
code = code.replace('from torch', '# from torch')
code = code.replace('(nn.Module):', '(nn.Cell):')
code = code.replace('forward(', 'construct(')
code = code.replace('nn.Linear', 'nn.Dense')
code = code.replace('(nn.Sequential)', '(nn.SequentialCell)')
code = code.replace('nn.init.', 'pass # nn.init.')
self.convert_info += '||[Import Annotated] Annotated follow import sentences:\n'
self.convert_info += 'import sentence on torch as follows are annotated:\n'
self.convert_info += 'import torch\n'
self.convert_info += 'from torch ...\n'
self.convert_info += '||[Explicit Convert] Module or function are explicitly converted as follows:\n'
self.convert_info += '[nn.Module] is converted to [nn.Cell]\n'
self.convert_info += '[forward] is converted to [construct]\n'
self.convert_info += '[nn.Linear] is converted to [nn.Dense]\n'
self.convert_info += '[nn.Sequential] is converted to [nn.SequentialCell]\n'
self.convert_info += '[nn.init] is not converted and annotated\n'
self.convert_info += '[Convert over]'
dest_file = os.path.join(output_dir, os.path.basename(srcfile))
with os.fdopen(os.open(dest_file, self.flags, self.modes), 'w') as file:
file.write(code)
logger.info("Convert success. Result is wrote to %s.", dest_file)
dest_report_file = os.path.join(report_dir,
'_'.join(os.path.basename(srcfile).split('.')[:-1]) + '_report.txt')
with os.fdopen(os.open(dest_report_file, self.flags, self.modes), 'a') as file:
file.write(self.convert_info)
logger.info("Convert report is saved in %s", dest_report_file)
def _get_name_ext(file):
""" """
Find api name from code with a start index i, check api name ok with a is_forward condition. Split a file name in name and extension.
Args:
code (str): The code from which to find api name.
i (int): The start index to find.
is_forward (bool): Check if the found api name ok.
Returns:
str, api name if find api name and check ok with is_forward condition, else a null character string.
"""
if code[i:].startswith("nn.") \
or code[i:].startswith("F.") \
or code[i:].startswith("torch.") \
or code[i:].startswith('.'):
j = code.find('(', i)
if j != -1 and code[i:j] in ALL_TORCH_APIS:
api_name = code[i:j]
if (not is_forward and api_name in NN_LIST) or (is_forward and api_name in ALL_2P_LIST):
return api_name
return ""
def convert_function(fun_name, fun, is_forward):
"""
Convert a PyTorch function into MindSpore function.
Args: Args:
fun_name (str): The str of function name. file (str): Full file path.
fun (func): The function to convert.
is_forward (bool): If the function is defined in forward function in nn.Module in torch.
Returns: Returns:
dict, old code and converted code map if convert happens, else {}. tuple (str, str), name and extension.
""" """
_, line_no = inspect.getsourcelines(fun) _, name = os.path.split(file)
logger.info("Line %3d: start converting function %s()", line_no, fun_name) return os.path.splitext(name)
code = inspect.getsource(fun)
code_saved = copy.copy(code)
i = 0
while i < len(code):
api_name = find_api(code, i, is_forward)
if api_name:
line_no1 = line_no + code[:i].count('\n')
if api_name in ALL_MAPPING:
logger.info("Line %3d start converting API: %s", line_no1, api_name)
code, i = convert_api(code, i, api_name)
continue
warn_info = ". " + UNSUPPORTED_WARN_INFOS[api_name] if api_name in UNSUPPORTED_WARN_INFOS else ""
logger.warning("Line %3d: found unsupported API: %s%s", line_no1, api_name, warn_info)
i += 1
return {code_saved: code} if code_saved != code else {}
def judge_forward(name, forward_list): def _path_split(file):
""" """
Check if function is a forward function. Split a path in head and tail.
Args: Args:
name (str): The function name. file (str): The file path.
forward_list (set): A set of forward function.
Returns: Returns:
bool, True or False list[str], list of file tail
"""
is_forward = name in forward_list or name.split(".")[-1] == "forward"
if is_forward:
logger.debug("%s is a forward function", name)
return is_forward
def convert_module(module_name, module, forward_list):
""" """
Convert a PyTorch module code into MindSpore module code. file_dir, name = os.path.split(file)
Args: if file_dir:
module_name (str): The module's name. sep = file[len(file_dir)-1]
module (module): The module to convert. if file_dir.startswith(sep):
forward_list (set): A set of forward function. return file.split(sep)[1:]
Returns: return file.split(sep)
dict, map of old code and converted code. return [name]
"""
_, line_no = inspect.getsourcelines(module)
logger.info("Line {:3d}: start converting nn.Module {}".format(line_no, module_name))
mapped = {}
for name, member in inspect.getmembers(module):
if is_valid_function(module, member):
is_forward = judge_forward("{}.{}".format(module_name, name), forward_list)
mapped.update(convert_function(name, member, is_forward))
return mapped
def get_mapping(import_mod, forward_list): def main(files_config):
""" """
Convert code of a module and get mapping of old code and convert code. The entrance for converter, script files will be converted.
Args:
import_mod (module): The module to convert.
forward_list (set): A set of forward function.
Returns:
dict, mapping for old code and converted code of the module
"""
mapping = {}
tasks = []
for name, member in inspect.getmembers(import_mod):
if is_valid_module(import_mod, member):
_, line_no = inspect.getsourcelines(member)
tasks.append((line_no, convert_module, (name, member, forward_list)))
elif is_valid_function(import_mod, member):
_, line_no = inspect.getsourcelines(member)
is_forward = judge_forward("{}.{}".format(import_mod, name), forward_list)
tasks.append((line_no, convert_function, (name, member, is_forward)))
tasks.sort()
for _, convert_fun, args in tasks:
mapping.update(convert_fun(*args))
return mapping
def convert(import_name, nn_module):
"""
The entrance for convert a module's code, code converted will be write to file called out.py.
Args: Args:
import_name (str): The module from which to import the module to convert. files_config (dict): The config of files which to convert.
nn_module (str): Name of the module to convert.
""" """
logger.info("Start converting %s.%s", import_name, nn_module) convert_ins = Converter()
import_mod = importlib.import_module(import_name) root_path = files_config['root_path']
in_files = files_config['in_files']
forward_list = set() for in_file in in_files:
in_file_split = _path_split(in_file[len(root_path):])
logger.debug("Forward_list: %s", forward_list) in_file_split[-1], _ = _get_name_ext(in_file_split[-1])
module_name = '.'.join(in_file_split)
# replace python function under nn.Modlue convert_ins.convert(module_name, files_config['outfile_dir'], files_config['report_dir'])
mapping = get_mapping(import_mod, forward_list)
in_module = files_config['in_module']
code = inspect.getsource(import_mod) if in_module:
for key, value in mapping.items(): convert_ins.convert(in_module, files_config['outfile_dir'], files_config['report_dir'])
code = code.replace(key, value)
code = 'import mindspore.ops.operations as P\n' + code
code = 'import mindspore.nn as nn\n' + code
code = 'import mindspore\n' + code
code = code.replace('import torch', '# import torch')
code = code.replace('from torch', '# from torch')
code = code.replace('(nn.Module):', '(nn.Cell):')
code = code.replace('forward(', 'construct(')
code = code.replace('nn.Linear', 'nn.Dense')
code = code.replace('(nn.Sequential)', '(nn.SequentialCell)')
code = code.replace('nn.init.', 'pass # nn.init.')
flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL
modes = stat.S_IWUSR | stat.S_IRUSR
with os.fdopen(os.open('out.py', flags, modes), 'w') as file:
file.write(code)
logger.info("Convert success. Result is wrote to out.py\n")
if __name__ == '__main__':
convert('torchvision.models.resnet', 'resnet18')
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Find out forward functions of script file"""
import ast
import os
class ForwardCall(ast.NodeVisitor):
"""
AST visitor that processes forward calls.
Find the sub functions called by the forward function in the script file.
"""
def __init__(self, filename):
self.filename = filename
self.module_name = os.path.basename(filename).replace('.py', '')
self.name_stack = []
self.forward_stack = []
self.calls = []
self.process()
def process(self):
"""Parse the python source file to find the forward functions."""
with open(self.filename, 'rt', encoding='utf-8') as file:
content = file.read()
self.visit(ast.parse(content, self.filename))
def get_current_namespace(self):
"""Get the namespace when visit the AST node"""
namespace = '.'.join(self.name_stack)
return namespace
@classmethod
def get_ast_node_name(cls, node):
"""Get AST node name."""
if isinstance(node, ast.Attribute):
return f'{cls.get_ast_node_name(node.value)}.{node.attr}'
if isinstance(node, ast.Name):
return node.id
return node
def visit_ClassDef(self, node):
"""Callback function when visit AST tree"""
self.name_stack.append(node.name)
self.generic_visit(node)
self.name_stack.pop()
def visit_FunctionDef(self, node):
"""Callback function when visit AST tree"""
func_name = f'{self.get_current_namespace()}.{node.name}'
is_in_chain = func_name in self.calls or node.name == 'forward'
if is_in_chain:
self.forward_stack.append(func_name)
if node.name == 'forward':
self.calls.append(func_name)
self.generic_visit(node)
if is_in_chain:
self.forward_stack.pop()
def visit_Call(self, node):
"""Callback function when visit AST tree"""
for arg in node.args:
self.visit(arg)
for kw in node.keywords:
self.visit(kw.value)
func_name = self.get_ast_node_name(node.func)
if isinstance(node.func, ast.Name):
if func_name not in ['super', 'str', 'repr']:
if self.forward_stack:
self.calls.append(func_name)
self.visit(node.func)
else:
if self.forward_stack:
if 'self' in func_name:
self.calls.append(f'{self.get_current_namespace()}.{func_name.split(".")[-1]}')
else:
self.calls.append(func_name)
self.visit(node.func)
...@@ -194,6 +194,7 @@ if __name__ == '__main__': ...@@ -194,6 +194,7 @@ if __name__ == '__main__':
entry_points={ entry_points={
'console_scripts': [ 'console_scripts': [
'mindinsight=mindinsight.utils.command:main', 'mindinsight=mindinsight.utils.command:main',
'mindconverter=mindinsight.mindconverter.cli:cli_entry',
], ],
}, },
python_requires='>=3.7', python_requires='>=3.7',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册