提交 5b288467 编写于 作者: Q quyongxiu1

add cli and report

fix review pros

add readme
上级 98516a84
### Introduction
MindConverter is a tool that converting PyTorch scripts to MindSpore scripts. With minial manual editing and the guidance from conversion reports, users may easily migrate their model from PyTorch framework to MindSpore.
### System Requirements
* PyTorch v1.5.0
* MindSpore v0.2.0
### Installation
This tool is part of MindInsight and accessible to users after installing MindInsight, no extra installation is needed.
### Commandline Usage
Set the model scripts directory as the PYTHONPATH environment variable first:
```buildoutcfg
export PYTHONPATH=<model scripts dir>
```
mindconverter commandline usage:
```buildoutcfg
mindconverter [-h] [--version] --in_file IN_FILE [--output OUTPUT]
[--report REPORT]
MindConverter CLI entry point (version: 0.2.0)
optional arguments:
-h, --help show this help message and exit
--version show program's version number and exit
--in_file IN_FILE Specify path for script file.
--output OUTPUT Specify path for converted script file directory. Default
is output directory in the current working directory.
--report REPORT Specify report directory. Default is the current working
directorys
```
Usage example:
```buildoutcfg
export PYTHONPATH=~/my_pt_proj/models
mindconverter --in_file lenet.py
```
Since the conversion is not 100% flawless, we encourage users to checkout the reports when fixing issues of the converted scripts.
### Unsupported Situation #1
Classes and functions that can't be converted:
* The use of shape, ndim and dtype member of torch.Tensor.
* torch.nn.AdaptiveXXXPoolXd and torch.nn.functional.adaptive_XXX_poolXd()
* torch.nn.functional.Dropout
* torch.unsqueeze() and torch.Tensor.unsqueeze()
* torch.chunk() and torch.Tensor.chunk()
### Unsupported Situation #2
Subclassing from the subclasses of nn.Module
e.g. (code snip from torchvision,models.mobilenet)
```python
from torch import nn
class ConvBNReLU(nn.Sequential):
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
padding = (kernel_size - 1) // 2
super(ConvBNReLU, self).__init__(
nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
nn.BatchNorm2d(out_planes),
nn.ReLU6(inplace=True)
)
```
### Unsupported Situation #3
Unconventional import naming
e.g.
```python
import torch.nn as mm
```
\ No newline at end of file
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Command module."""
import os
import sys
import argparse
import mindinsight
from mindinsight.mindconverter.converter import main
class FileDirAction(argparse.Action):
"""File directory action class definition."""
@staticmethod
def check_path(parser, values, option_string=None):
"""
Check argument for file path.
Args:
parser (ArgumentParser): Passed-in argument parser.
values (object): Argument values with type depending on argument definition.
option_string (str): Optional string for specific argument name. Default: None.
"""
outfile = values
if outfile.startswith('~'):
outfile = os.path.realpath(os.path.expanduser(outfile))
if not outfile.startswith('/'):
outfile = os.path.realpath(os.path.join(os.getcwd(), outfile))
if os.path.exists(outfile) and not os.access(outfile, os.R_OK):
parser.error(f'{option_string} {outfile} not accessible')
return outfile
def __call__(self, parser, namespace, values, option_string=None):
"""
Inherited __call__ method from argparse.Action.
Args:
parser (ArgumentParser): Passed-in argument parser.
namespace (Namespace): Namespace object to hold arguments.
values (object): Argument values with type depending on argument definition.
option_string (str): Optional string for specific argument name. Default: None.
"""
outfile_dir = self.check_path(parser, values, option_string)
if os.path.isfile(outfile_dir):
parser.error(f'{option_string} {outfile_dir} is a file')
setattr(namespace, self.dest, outfile_dir)
class OutputDirAction(argparse.Action):
"""File directory action class definition."""
def __call__(self, parser, namespace, values, option_string=None):
"""
Inherited __call__ method from argparse.Action.
Args:
parser (ArgumentParser): Passed-in argument parser.
namespace (Namespace): Namespace object to hold arguments.
values (object): Argument values with type depending on argument definition.
option_string (str): Optional string for specific argument name. Default: None.
"""
output = values
if output.startswith('~'):
output = os.path.realpath(os.path.expanduser(output))
if not output.startswith('/'):
output = os.path.realpath(os.path.join(os.getcwd(), output))
if os.path.exists(output):
if not os.access(output, os.R_OK):
parser.error(f'{option_string} {output} not accessible')
if os.path.isfile(output):
parser.error(f'{option_string} {output} is a file')
setattr(namespace, self.dest, output)
class InFileAction(argparse.Action):
"""Input File action class definition."""
def __call__(self, parser, namespace, values, option_string=None):
"""
Inherited __call__ method from argparse.Action.
Args:
parser (ArgumentParser): Passed-in argument parser.
namespace (Namespace): Namespace object to hold arguments.
values (object): Argument values with type depending on argument definition.
option_string (str): Optional string for specific argument name. Default: None.
"""
outfile_dir = FileDirAction.check_path(parser, values, option_string)
if not os.path.exists(outfile_dir):
parser.error(f'{option_string} {outfile_dir} not exists')
if not os.path.isfile(outfile_dir):
parser.error(f'{option_string} {outfile_dir} is not a file')
setattr(namespace, self.dest, outfile_dir)
class LogFileAction(argparse.Action):
"""Log file action class definition."""
def __call__(self, parser, namespace, values, option_string=None):
"""
Inherited __call__ method from FileDirAction.
Args:
parser (ArgumentParser): Passed-in argument parser.
namespace (Namespace): Namespace object to hold arguments.
values (object): Argument values with type depending on argument definition.
option_string (str): Optional string for specific argument name. Default: None.
"""
outfile_dir = FileDirAction.check_path(parser, values, option_string)
if os.path.exists(outfile_dir) and not os.path.isdir(outfile_dir):
parser.error(f'{option_string} {outfile_dir} is not a directory')
setattr(namespace, self.dest, outfile_dir)
def cli_entry():
"""Entry point for mindconverter CLI."""
permissions = os.R_OK | os.W_OK | os.X_OK
os.umask(permissions << 3 | permissions)
parser = argparse.ArgumentParser(
prog='mindconverter',
description='MindConverter CLI entry point (version: {})'.format(mindinsight.__version__))
parser.add_argument(
'--version',
action='version',
version='%(prog)s ({})'.format(mindinsight.__version__))
parser.add_argument(
'--in_file',
type=str,
action=InFileAction,
required=True,
help="""
Specify path for script file.
""")
parser.add_argument(
'--output',
type=str,
action=OutputDirAction,
default=os.path.join(os.getcwd(), 'output'),
help="""
Specify path for converted script file directory.
Default is output directory in the current working directory.
""")
parser.add_argument(
'--report',
type=str,
action=LogFileAction,
default=os.getcwd(),
help="""
Specify report directory. Default is the current working directory.
""")
argv = sys.argv[1:]
if not argv:
argv = ['-h']
args = parser.parse_args(argv)
else:
args = parser.parse_args()
mode = permissions << 6
os.makedirs(args.output, mode=mode, exist_ok=True)
os.makedirs(args.report, mode=mode, exist_ok=True)
_run(args.in_file, args.output, '', args.report)
def _run(in_files, out_dir, in_module, report):
"""
Run converter command.
Args:
in_files (str): The file path or directory to convert.
out_dir (str): The output directory to save converted file.
in_module (str): The module name to convert.
report (str): The report file path.
"""
files_config = {
'root_path': in_files if in_files else '',
'in_files': [],
'outfile_dir': out_dir,
'report_dir': report,
'in_module': in_module
}
if os.path.isfile(in_files):
files_config['root_path'] = os.path.dirname(in_files)
files_config['in_files'] = [in_files]
else:
for root_dir, _, files in os.walk(in_files):
for file in files:
files_config['in_files'].append(os.path.join(root_dir, file))
main(files_config)
...@@ -15,4 +15,4 @@ ...@@ -15,4 +15,4 @@
"""Create a logger.""" """Create a logger."""
from mindinsight.utils.log import setup_logger from mindinsight.utils.log import setup_logger
logger = setup_logger("mindconverter", "mindconverter") logger = setup_logger("mindconverter", "mindconverter", console=False)
...@@ -222,6 +222,7 @@ class MappingHelper: ...@@ -222,6 +222,7 @@ class MappingHelper:
def convert(self, call_name_pt: str, args_str_pt: str): def convert(self, call_name_pt: str, args_str_pt: str):
""" """
Convert code sentence to MindSpore code sentence. Convert code sentence to MindSpore code sentence.
Args: Args:
call_name_pt (str): str of the call function, etc. call_name_pt (str): str of the call function, etc.
args_str_pt (str): str of args for function, which starts with '(' and end with ')'. args_str_pt (str): str of args for function, which starts with '(' and end with ')'.
...@@ -336,11 +337,12 @@ def gen_explicit_map_one_delta(params_pt, args_pt, k_ms, k_pt): ...@@ -336,11 +337,12 @@ def gen_explicit_map_one_delta(params_pt, args_pt, k_ms, k_pt):
def load_json_file(file_path): def load_json_file(file_path):
""" """
Load data from given json file path. Load data from given json file path.
Args: Args:
file_path (str): The file to load json data from. file_path (str): The file to load json data from.
Returns: Returns:
list, the list data stored in file_path. list(str), the list data stored in file_path.
""" """
with open(file_path, 'r', encoding='utf-8') as file: with open(file_path, 'r', encoding='utf-8') as file:
info = json.loads(file.read()) info = json.loads(file.read())
......
...@@ -12,10 +12,10 @@ ...@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""main module""" """converter module"""
import inspect
import copy import copy
import importlib import importlib
import inspect
import os import os
import stat import stat
...@@ -24,12 +24,23 @@ from mindinsight.mindconverter.config import NN_LIST ...@@ -24,12 +24,23 @@ from mindinsight.mindconverter.config import NN_LIST
from mindinsight.mindconverter.config import ALL_TORCH_APIS from mindinsight.mindconverter.config import ALL_TORCH_APIS
from mindinsight.mindconverter.config import ALL_2P_LIST from mindinsight.mindconverter.config import ALL_2P_LIST
from mindinsight.mindconverter.config import UNSUPPORTED_WARN_INFOS from mindinsight.mindconverter.config import UNSUPPORTED_WARN_INFOS
from mindinsight.mindconverter.config import ALL_UNSUPPORTED
from mindinsight.mindconverter.common.log import logger from mindinsight.mindconverter.common.log import logger
from mindinsight.mindconverter.forward_call import ForwardCall
class Converter:
"""Convert class"""
convert_info = ''
flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL
modes = stat.S_IWUSR | stat.S_IRUSR
def is_local_defined(obj, member): @staticmethod
def is_local_defined(obj, member):
""" """
Check if obj and member are both defined in the same source file. Check if obj and member are both defined in the same source file.
Args: Args:
obj (Union[object, module]): A module or a class. obj (Union[object, module]): A module or a class.
member (func): A function of obj. member (func): A function of obj.
...@@ -40,10 +51,11 @@ def is_local_defined(obj, member): ...@@ -40,10 +51,11 @@ def is_local_defined(obj, member):
srcfile = inspect.getsourcefile(obj) srcfile = inspect.getsourcefile(obj)
return inspect.getsourcefile(member) == srcfile return inspect.getsourcefile(member) == srcfile
@classmethod
def is_valid_module(obj, member): def is_valid_module(cls, obj, member):
""" """
Check if obj and member defined in same source file and member is inherited from torch.nn.Module. Check if obj and member defined in same source file and member is inherited from torch.nn.Module.
Args: Args:
obj (Union[object, module]): A module or a class. obj (Union[object, module]): A module or a class.
member (func): A function. member (func): A function.
...@@ -51,12 +63,21 @@ def is_valid_module(obj, member): ...@@ -51,12 +63,21 @@ def is_valid_module(obj, member):
Returns: Returns:
bool, True or False. bool, True or False.
""" """
return inspect.isclass(member) and (member.__base__.__name__ == 'Module') and is_local_defined(obj, member) if inspect.isclass(member):
is_subclass = member.__base__.__name__ in ['Module',
'Sequential',
def is_valid_function(obj, member): 'ModuleList',
'ModuleDict',
'ParameterList',
'ParameterDict']
return is_subclass and cls.is_local_defined(obj, member)
return False
@classmethod
def is_valid_function(cls, obj, member):
""" """
Check if member is function and defined in the file same as obj. Check if member is function and defined in the file same as obj.
Args: Args:
obj (Union[object, module]: The obj. obj (Union[object, module]: The obj.
member (func): The func. member (func): The func.
...@@ -64,15 +85,16 @@ def is_valid_function(obj, member): ...@@ -64,15 +85,16 @@ def is_valid_function(obj, member):
Returns: Returns:
bool, True or False. bool, True or False.
""" """
return inspect.isfunction(member) and is_local_defined(obj, member) return inspect.isfunction(member) and cls.is_local_defined(obj, member)
@staticmethod
def find_left_parentheses(string, right): def find_left_parentheses(string, right):
""" """
Find index of the first left parenthesis. Find index of the first left parenthesis.
Args: Args:
string (str): A line of code. string (str): A line of code.
right (int): Max index of string, same as `len(string) -1`. right (int): The right index for string to find from.
Returns: Returns:
int, index of the first parenthesis. int, index of the first parenthesis.
...@@ -92,10 +114,11 @@ def find_left_parentheses(string, right): ...@@ -92,10 +114,11 @@ def find_left_parentheses(string, right):
return i return i
raise ValueError("{} should contain ()".format(string)) raise ValueError("{} should contain ()".format(string))
@staticmethod
def find_right_parentheses(string, left): def find_right_parentheses(string, left):
""" """
Find first index of right parenthesis which make all left parenthesis make sense. Find first index of right parenthesis which make all left parenthesis make sense.
Args: Args:
string (str): A line of code. string (str): A line of code.
left (int): Start index of string to find from. left (int): Start index of string to find from.
...@@ -116,19 +139,19 @@ def find_right_parentheses(string, left): ...@@ -116,19 +139,19 @@ def find_right_parentheses(string, left):
return i return i
raise ValueError("{} should contain ()".format(string)) raise ValueError("{} should contain ()".format(string))
@staticmethod
def get_call_name(code, end): def get_call_name(code, end):
""" """
Traverse code in a reversed function from index end and get the call name and start index of the call name, if call Traverse code in a reversed function from index end and get the call name and start index of the call name,
name not found, return a null character string and -1 if call name not found, return a null character string and -1
Args: Args:
code (str): The str of code to find from. code (str): The str of code to find from.
end (int): Start index to find. end (int): Start index to find.
Returns: Returns:
str, founded api name if found, else a null character string. tuple(str, int), one is founded api name if found, else a null character string, the other is start index
int, start index of founded api name, -1 if api name not found of founded api name, -1 if api name not found
""" """
stack = [] stack = []
for i in range(end - 1, -1, -1): for i in range(end - 1, -1, -1):
...@@ -145,11 +168,10 @@ def get_call_name(code, end): ...@@ -145,11 +168,10 @@ def get_call_name(code, end):
return code[i + 1:end], i + 1 return code[i + 1:end], i + 1
return "", -1 return "", -1
def convert_api(self, code, start, api_name=""):
def convert_api(code, start, api_name=""):
""" """
Convert api_name in code to MindSpore api with start as a start index, if api_name is a python api, code will not Convert api_name in code to MindSpore api with start as a start index, if api_name is a python api,
convert. code will not convert.
Args: Args:
code (str): The str code to convert. code (str): The str code to convert.
...@@ -159,11 +181,10 @@ def convert_api(code, start, api_name=""): ...@@ -159,11 +181,10 @@ def convert_api(code, start, api_name=""):
Returns: Returns:
str, the converted code. str, the converted code.
int, index of converted api_name in code. int, index of converted api_name in code.
""" """
# handle format like .shape( # handle format like .shape(
if api_name.startswith('.'): if api_name.startswith('.'):
call_name, new_start = get_call_name(code, start) call_name, new_start = self.get_call_name(code, start)
if start == -1 or call_name == "self": if start == -1 or call_name == "self":
return code, start + 1 return code, start + 1
else: else:
...@@ -174,7 +195,7 @@ def convert_api(code, start, api_name=""): ...@@ -174,7 +195,7 @@ def convert_api(code, start, api_name=""):
left = code.find("(", start) left = code.find("(", start)
if left == -1: if left == -1:
raise ValueError('"(" not found, {} should work with "("'.format(call_name)) raise ValueError('"(" not found, {} should work with "("'.format(call_name))
right = find_right_parentheses(code, left) right = self.find_right_parentheses(code, left)
end = right end = right
expr = code[start:end + 1] expr = code[start:end + 1]
args_str = code[left:right + 1] args_str = code[left:right + 1]
...@@ -187,12 +208,14 @@ def convert_api(code, start, api_name=""): ...@@ -187,12 +208,14 @@ def convert_api(code, start, api_name=""):
code = code[:new_start] + new_expr + code[end + 1:next_newline] + ("\n" * fill_num) + code[next_newline:] code = code[:new_start] + new_expr + code[end + 1:next_newline] + ("\n" * fill_num) + code[next_newline:]
else: else:
code = code[:new_start] + new_expr + ")" + ("\n" * fill_num) + code[end + 2:] code = code[:new_start] + new_expr + ")" + ("\n" * fill_num) + code[end + 2:]
return code, start + len(map_helper.ms_api.name)
return code, start + len(map_helper.ms_api.name)
def find_api(code, i, is_forward): @staticmethod
def find_api(code, i, is_forward):
""" """
Find api name from code with a start index i, check api name ok with a is_forward condition. Find api name from code with a start index i, check api name ok with a is_forward condition.
Args: Args:
code (str): The code from which to find api name. code (str): The code from which to find api name.
i (int): The start index to find. i (int): The start index to find.
...@@ -212,10 +235,10 @@ def find_api(code, i, is_forward): ...@@ -212,10 +235,10 @@ def find_api(code, i, is_forward):
return api_name return api_name
return "" return ""
def convert_function(self, fun_name, fun, is_forward):
def convert_function(fun_name, fun, is_forward):
""" """
Convert a PyTorch function into MindSpore function. Convert a PyTorch function into MindSpore function.
Args: Args:
fun_name (str): The str of function name. fun_name (str): The str of function name.
fun (func): The function to convert. fun (func): The function to convert.
...@@ -232,20 +255,24 @@ def convert_function(fun_name, fun, is_forward): ...@@ -232,20 +255,24 @@ def convert_function(fun_name, fun, is_forward):
i = 0 i = 0
while i < len(code): while i < len(code):
api_name = find_api(code, i, is_forward) api_name = self.find_api(code, i, is_forward)
if api_name: if api_name:
line_no1 = line_no + code[:i].count('\n') line_no1 = line_no + code[:i].count('\n')
if api_name in ALL_MAPPING: if api_name in ALL_MAPPING:
logger.info("Line %3d start converting API: %s", line_no1, api_name) logger.info("Line %3d start converting API: %s", line_no1, api_name)
code, i = convert_api(code, i, api_name) code, i = self.convert_api(code, i, api_name)
self.convert_info += "[Convert][Line{:3d}] {} is converted.\n".format(line_no1, api_name)
continue continue
if api_name in ALL_UNSUPPORTED:
warn_info = ". " + UNSUPPORTED_WARN_INFOS[api_name] if api_name in UNSUPPORTED_WARN_INFOS else "" warn_info = ". " + UNSUPPORTED_WARN_INFOS[api_name] if api_name in UNSUPPORTED_WARN_INFOS else ""
logger.warning("Line %3d: found unsupported API: %s%s", line_no1, api_name, warn_info) logger.warning("Line %3d: found unsupported API: %s%s", line_no1, api_name, warn_info)
self.convert_info += "[Unconvert][Line{:3d}] {} didn't convert{}\n".format(line_no1,
api_name, warn_info)
i += 1 i += 1
return {code_saved: code} if code_saved != code else {} return {code_saved: code} if code_saved != code else {}
@staticmethod
def judge_forward(name, forward_list): def judge_forward(name, forward_list):
""" """
Check if function is a forward function. Check if function is a forward function.
...@@ -261,10 +288,10 @@ def judge_forward(name, forward_list): ...@@ -261,10 +288,10 @@ def judge_forward(name, forward_list):
logger.debug("%s is a forward function", name) logger.debug("%s is a forward function", name)
return is_forward return is_forward
def convert_module(self, module_name, module, forward_list):
def convert_module(module_name, module, forward_list):
""" """
Convert a PyTorch module code into MindSpore module code. Convert a PyTorch module code into MindSpore module code.
Args: Args:
module_name (str): The module's name. module_name (str): The module's name.
module (module): The module to convert. module (module): The module to convert.
...@@ -278,15 +305,15 @@ def convert_module(module_name, module, forward_list): ...@@ -278,15 +305,15 @@ def convert_module(module_name, module, forward_list):
mapped = {} mapped = {}
for name, member in inspect.getmembers(module): for name, member in inspect.getmembers(module):
if is_valid_function(module, member): if self.is_valid_function(module, member):
is_forward = judge_forward("{}.{}".format(module_name, name), forward_list) is_forward = self.judge_forward("{}.{}".format(module_name, name), forward_list)
mapped.update(convert_function(name, member, is_forward)) mapped.update(self.convert_function(name, member, is_forward))
return mapped return mapped
def get_mapping(self, import_mod, forward_list):
def get_mapping(import_mod, forward_list):
""" """
Convert code of a module and get mapping of old code and convert code. Convert code of a module and get mapping of old code and convert code.
Args: Args:
import_mod (module): The module to convert. import_mod (module): The module to convert.
forward_list (set): A set of forward function. forward_list (set): A set of forward function.
...@@ -297,36 +324,40 @@ def get_mapping(import_mod, forward_list): ...@@ -297,36 +324,40 @@ def get_mapping(import_mod, forward_list):
mapping = {} mapping = {}
tasks = [] tasks = []
for name, member in inspect.getmembers(import_mod): for name, member in inspect.getmembers(import_mod):
if is_valid_module(import_mod, member): if self.is_valid_module(import_mod, member):
_, line_no = inspect.getsourcelines(member) _, line_no = inspect.getsourcelines(member)
tasks.append((line_no, convert_module, (name, member, forward_list))) tasks.append((line_no, self.convert_module, (name, member, forward_list)))
elif is_valid_function(import_mod, member): elif self.is_valid_function(import_mod, member):
_, line_no = inspect.getsourcelines(member) _, line_no = inspect.getsourcelines(member)
is_forward = judge_forward("{}.{}".format(import_mod, name), forward_list) is_forward = self.judge_forward("{}.{}".format(import_mod, name), forward_list)
tasks.append((line_no, convert_function, (name, member, is_forward))) tasks.append((line_no, self.convert_function, (name, member, is_forward)))
tasks.sort() tasks.sort()
for _, convert_fun, args in tasks: for _, convert_fun, args in tasks:
mapping.update(convert_fun(*args)) mapping.update(convert_fun(*args))
return mapping return mapping
def convert(self, import_name, output_dir, report_dir):
def convert(import_name, nn_module):
""" """
The entrance for convert a module's code, code converted will be write to file called out.py. Convert a module's code, code converted will be save in output_dir, and a report will be save in report_dir.
Args: Args:
import_name (str): The module from which to import the module to convert. import_name (str): The module from which to import the module to convert.
nn_module (str): Name of the module to convert. output_dir (str): The path to save converted file.
report_dir (str): The path to save report file.
""" """
logger.info("Start converting %s.%s", import_name, nn_module) logger.info("Start converting %s", import_name)
self.convert_info += '[Start Convert]\nThe module is {}\n'.format(import_name)
import_mod = importlib.import_module(import_name) import_mod = importlib.import_module(import_name)
forward_list = set() srcfile = inspect.getsourcefile(import_mod)
logger.info("Script file is %s", srcfile)
forward_list = set(ForwardCall(srcfile).calls)
logger.debug("Forward_list: %s", forward_list) logger.debug("Forward_list: %s", forward_list)
# replace python function under nn.Modlue # replace python function under nn.Module
mapping = get_mapping(import_mod, forward_list) mapping = self.get_mapping(import_mod, forward_list)
code = inspect.getsource(import_mod) code = inspect.getsource(import_mod)
for key, value in mapping.items(): for key, value in mapping.items():
...@@ -335,6 +366,12 @@ def convert(import_name, nn_module): ...@@ -335,6 +366,12 @@ def convert(import_name, nn_module):
code = 'import mindspore.ops.operations as P\n' + code code = 'import mindspore.ops.operations as P\n' + code
code = 'import mindspore.nn as nn\n' + code code = 'import mindspore.nn as nn\n' + code
code = 'import mindspore\n' + code code = 'import mindspore\n' + code
self.convert_info += '||[Import Add] Add follow import sentences:\n'
self.convert_info += 'import mindspore.ops.operations as P\n'
self.convert_info += 'import mindspore.nn as nn\n'
self.convert_info += 'import mindspore\n\n'
code = code.replace('import torch', '# import torch') code = code.replace('import torch', '# import torch')
code = code.replace('from torch', '# from torch') code = code.replace('from torch', '# from torch')
code = code.replace('(nn.Module):', '(nn.Cell):') code = code.replace('(nn.Module):', '(nn.Cell):')
...@@ -343,13 +380,82 @@ def convert(import_name, nn_module): ...@@ -343,13 +380,82 @@ def convert(import_name, nn_module):
code = code.replace('(nn.Sequential)', '(nn.SequentialCell)') code = code.replace('(nn.Sequential)', '(nn.SequentialCell)')
code = code.replace('nn.init.', 'pass # nn.init.') code = code.replace('nn.init.', 'pass # nn.init.')
flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL self.convert_info += '||[Import Annotated] Annotated follow import sentences:\n'
modes = stat.S_IWUSR | stat.S_IRUSR self.convert_info += 'import sentence on torch as follows are annotated:\n'
with os.fdopen(os.open('out.py', flags, modes), 'w') as file: self.convert_info += 'import torch\n'
self.convert_info += 'from torch ...\n'
self.convert_info += '||[Explicit Convert] Module or function are explicitly converted as follows:\n'
self.convert_info += '[nn.Module] is converted to [nn.Cell]\n'
self.convert_info += '[forward] is converted to [construct]\n'
self.convert_info += '[nn.Linear] is converted to [nn.Dense]\n'
self.convert_info += '[nn.Sequential] is converted to [nn.SequentialCell]\n'
self.convert_info += '[nn.init] is not converted and annotated\n'
self.convert_info += '[Convert over]'
dest_file = os.path.join(output_dir, os.path.basename(srcfile))
with os.fdopen(os.open(dest_file, self.flags, self.modes), 'w') as file:
file.write(code) file.write(code)
logger.info("Convert success. Result is wrote to out.py\n") logger.info("Convert success. Result is wrote to %s.", dest_file)
dest_report_file = os.path.join(report_dir,
'_'.join(os.path.basename(srcfile).split('.')[:-1]) + '_report.txt')
with os.fdopen(os.open(dest_report_file, self.flags, self.modes), 'a') as file:
file.write(self.convert_info)
logger.info("Convert report is saved in %s", dest_report_file)
if __name__ == '__main__': def _get_name_ext(file):
"""
Split a file name in name and extension.
Args:
file (str): Full file path.
Returns:
tuple (str, str), name and extension.
"""
_, name = os.path.split(file)
return os.path.splitext(name)
convert('torchvision.models.resnet', 'resnet18') def _path_split(file):
"""
Split a path in head and tail.
Args:
file (str): The file path.
Returns:
list[str], list of file tail
"""
file_dir, name = os.path.split(file)
if file_dir:
sep = file[len(file_dir)-1]
if file_dir.startswith(sep):
return file.split(sep)[1:]
return file.split(sep)
return [name]
def main(files_config):
"""
The entrance for converter, script files will be converted.
Args:
files_config (dict): The config of files which to convert.
"""
convert_ins = Converter()
root_path = files_config['root_path']
in_files = files_config['in_files']
for in_file in in_files:
in_file_split = _path_split(in_file[len(root_path):])
in_file_split[-1], _ = _get_name_ext(in_file_split[-1])
module_name = '.'.join(in_file_split)
convert_ins.convert(module_name, files_config['outfile_dir'], files_config['report_dir'])
in_module = files_config['in_module']
if in_module:
convert_ins.convert(in_module, files_config['outfile_dir'], files_config['report_dir'])
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Find out forward functions of script file"""
import ast
import os
class ForwardCall(ast.NodeVisitor):
"""
AST visitor that processes forward calls.
Find the sub functions called by the forward function in the script file.
"""
def __init__(self, filename):
self.filename = filename
self.module_name = os.path.basename(filename).replace('.py', '')
self.name_stack = []
self.forward_stack = []
self.calls = []
self.process()
def process(self):
"""Parse the python source file to find the forward functions."""
with open(self.filename, 'rt', encoding='utf-8') as file:
content = file.read()
self.visit(ast.parse(content, self.filename))
def get_current_namespace(self):
"""Get the namespace when visit the AST node"""
namespace = '.'.join(self.name_stack)
return namespace
@classmethod
def get_ast_node_name(cls, node):
"""Get AST node name."""
if isinstance(node, ast.Attribute):
return f'{cls.get_ast_node_name(node.value)}.{node.attr}'
if isinstance(node, ast.Name):
return node.id
return node
def visit_ClassDef(self, node):
"""Callback function when visit AST tree"""
self.name_stack.append(node.name)
self.generic_visit(node)
self.name_stack.pop()
def visit_FunctionDef(self, node):
"""Callback function when visit AST tree"""
func_name = f'{self.get_current_namespace()}.{node.name}'
is_in_chain = func_name in self.calls or node.name == 'forward'
if is_in_chain:
self.forward_stack.append(func_name)
if node.name == 'forward':
self.calls.append(func_name)
self.generic_visit(node)
if is_in_chain:
self.forward_stack.pop()
def visit_Call(self, node):
"""Callback function when visit AST tree"""
for arg in node.args:
self.visit(arg)
for kw in node.keywords:
self.visit(kw.value)
func_name = self.get_ast_node_name(node.func)
if isinstance(node.func, ast.Name):
if func_name not in ['super', 'str', 'repr']:
if self.forward_stack:
self.calls.append(func_name)
self.visit(node.func)
else:
if self.forward_stack:
if 'self' in func_name:
self.calls.append(f'{self.get_current_namespace()}.{func_name.split(".")[-1]}')
else:
self.calls.append(func_name)
self.visit(node.func)
...@@ -194,6 +194,7 @@ if __name__ == '__main__': ...@@ -194,6 +194,7 @@ if __name__ == '__main__':
entry_points={ entry_points={
'console_scripts': [ 'console_scripts': [
'mindinsight=mindinsight.utils.command:main', 'mindinsight=mindinsight.utils.command:main',
'mindconverter=mindinsight.mindconverter.cli:cli_entry',
], ],
}, },
python_requires='>=3.7', python_requires='>=3.7',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册