diff --git a/mindinsight/mindconverter/README.md b/mindinsight/mindconverter/README.md new file mode 100644 index 0000000000000000000000000000000000000000..5d7d598a4aded3ef7a2b59334560da7512cfdadd --- /dev/null +++ b/mindinsight/mindconverter/README.md @@ -0,0 +1,82 @@ +### Introduction + +MindConverter is a tool that converts PyTorch scripts to MindSpore scripts. With minimal manual editing and the guidance from conversion reports, users may easily migrate their model from PyTorch framework to MindSpore. + + + +### System Requirements + +* PyTorch v1.5.0 +* MindSpore v0.2.0 + +### Installation + +This tool is part of MindInsight and is available to users once MindInsight is installed; no extra installation is needed. + +### Commandline Usage +Set the model scripts directory as the PYTHONPATH environment variable first: +```buildoutcfg +export PYTHONPATH= +``` + +mindconverter commandline usage: +```buildoutcfg +mindconverter [-h] [--version] --in_file IN_FILE [--output OUTPUT] + [--report REPORT] + +MindConverter CLI entry point (version: 0.2.0) + +optional arguments: + -h, --help show this help message and exit + --version show program's version number and exit + --in_file IN_FILE Specify path for script file. + --output OUTPUT Specify path for converted script file directory. Default + is output directory in the current working directory. + --report REPORT Specify report directory. Default is the current working + directory. +``` + +Usage example: +```buildoutcfg +export PYTHONPATH=~/my_pt_proj/models +mindconverter --in_file lenet.py +``` + +Since the conversion is not 100% flawless, we encourage users to check out the reports when fixing issues of the converted scripts. + + +### Unsupported Situation #1 +Classes and functions that can't be converted: +* The use of shape, ndim and dtype members of torch.Tensor. 
+* torch.nn.AdaptiveXXXPoolXd and torch.nn.functional.adaptive_XXX_poolXd() +* torch.nn.functional.Dropout +* torch.unsqueeze() and torch.Tensor.unsqueeze() +* torch.chunk() and torch.Tensor.chunk() + +### Unsupported Situation #2 + +Subclassing from the subclasses of nn.Module + +e.g. (code snippet from torchvision.models.mobilenet) + +```python +from torch import nn + +class ConvBNReLU(nn.Sequential): + def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): + padding = (kernel_size - 1) // 2 + super(ConvBNReLU, self).__init__( + nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), + nn.BatchNorm2d(out_planes), + nn.ReLU6(inplace=True) + ) +``` + +### Unsupported Situation #3 + +Unconventional import naming + +e.g. +```python +import torch.nn as mm +``` \ No newline at end of file diff --git a/mindinsight/mindconverter/cli.py b/mindinsight/mindconverter/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..7e0a7d67b55a1cbba8f59c2a100f462ee110c56a --- /dev/null +++ b/mindinsight/mindconverter/cli.py @@ -0,0 +1,216 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""Command module.""" +import os +import sys +import argparse + +import mindinsight +from mindinsight.mindconverter.converter import main + + +class FileDirAction(argparse.Action): + """File directory action class definition.""" + + @staticmethod + def check_path(parser, values, option_string=None): + """ + Check argument for file path. + + Args: + parser (ArgumentParser): Passed-in argument parser. + values (object): Argument values with type depending on argument definition. + option_string (str): Optional string for specific argument name. Default: None. + """ + outfile = values + if outfile.startswith('~'): + outfile = os.path.realpath(os.path.expanduser(outfile)) + + if not outfile.startswith('/'): + outfile = os.path.realpath(os.path.join(os.getcwd(), outfile)) + + if os.path.exists(outfile) and not os.access(outfile, os.R_OK): + parser.error(f'{option_string} {outfile} not accessible') + return outfile + + def __call__(self, parser, namespace, values, option_string=None): + """ + Inherited __call__ method from argparse.Action. + + Args: + parser (ArgumentParser): Passed-in argument parser. + namespace (Namespace): Namespace object to hold arguments. + values (object): Argument values with type depending on argument definition. + option_string (str): Optional string for specific argument name. Default: None. + """ + outfile_dir = self.check_path(parser, values, option_string) + if os.path.isfile(outfile_dir): + parser.error(f'{option_string} {outfile_dir} is a file') + + setattr(namespace, self.dest, outfile_dir) + + +class OutputDirAction(argparse.Action): + """File directory action class definition.""" + + def __call__(self, parser, namespace, values, option_string=None): + """ + Inherited __call__ method from argparse.Action. + + Args: + parser (ArgumentParser): Passed-in argument parser. + namespace (Namespace): Namespace object to hold arguments. 
+ values (object): Argument values with type depending on argument definition. + option_string (str): Optional string for specific argument name. Default: None. + """ + output = values + if output.startswith('~'): + output = os.path.realpath(os.path.expanduser(output)) + + if not output.startswith('/'): + output = os.path.realpath(os.path.join(os.getcwd(), output)) + + if os.path.exists(output): + if not os.access(output, os.R_OK): + parser.error(f'{option_string} {output} not accessible') + + if os.path.isfile(output): + parser.error(f'{option_string} {output} is a file') + + setattr(namespace, self.dest, output) + + +class InFileAction(argparse.Action): + """Input File action class definition.""" + + def __call__(self, parser, namespace, values, option_string=None): + """ + Inherited __call__ method from argparse.Action. + + Args: + parser (ArgumentParser): Passed-in argument parser. + namespace (Namespace): Namespace object to hold arguments. + values (object): Argument values with type depending on argument definition. + option_string (str): Optional string for specific argument name. Default: None. + """ + outfile_dir = FileDirAction.check_path(parser, values, option_string) + if not os.path.exists(outfile_dir): + parser.error(f'{option_string} {outfile_dir} not exists') + + if not os.path.isfile(outfile_dir): + parser.error(f'{option_string} {outfile_dir} is not a file') + + setattr(namespace, self.dest, outfile_dir) + + +class LogFileAction(argparse.Action): + """Log file action class definition.""" + + def __call__(self, parser, namespace, values, option_string=None): + """ + Inherited __call__ method from FileDirAction. + + Args: + parser (ArgumentParser): Passed-in argument parser. + namespace (Namespace): Namespace object to hold arguments. + values (object): Argument values with type depending on argument definition. + option_string (str): Optional string for specific argument name. Default: None. 
+ """ + outfile_dir = FileDirAction.check_path(parser, values, option_string) + if os.path.exists(outfile_dir) and not os.path.isdir(outfile_dir): + parser.error(f'{option_string} {outfile_dir} is not a directory') + setattr(namespace, self.dest, outfile_dir) + + +def cli_entry(): + """Entry point for mindconverter CLI.""" + + permissions = os.R_OK | os.W_OK | os.X_OK + os.umask(permissions << 3 | permissions) + + parser = argparse.ArgumentParser( + prog='mindconverter', + description='MindConverter CLI entry point (version: {})'.format(mindinsight.__version__)) + + parser.add_argument( + '--version', + action='version', + version='%(prog)s ({})'.format(mindinsight.__version__)) + + parser.add_argument( + '--in_file', + type=str, + action=InFileAction, + required=True, + help=""" + Specify path for script file. + """) + + parser.add_argument( + '--output', + type=str, + action=OutputDirAction, + default=os.path.join(os.getcwd(), 'output'), + help=""" + Specify path for converted script file directory. + Default is output directory in the current working directory. + """) + + parser.add_argument( + '--report', + type=str, + action=LogFileAction, + default=os.getcwd(), + help=""" + Specify report directory. Default is the current working directory. + """) + + argv = sys.argv[1:] + if not argv: + argv = ['-h'] + args = parser.parse_args(argv) + else: + args = parser.parse_args() + mode = permissions << 6 + os.makedirs(args.output, mode=mode, exist_ok=True) + os.makedirs(args.report, mode=mode, exist_ok=True) + _run(args.in_file, args.output, '', args.report) + + +def _run(in_files, out_dir, in_module, report): + """ + Run converter command. + + Args: + in_files (str): The file path or directory to convert. + out_dir (str): The output directory to save converted file. + in_module (str): The module name to convert. + report (str): The report file path. 
+ """ + files_config = { + 'root_path': in_files if in_files else '', + 'in_files': [], + 'outfile_dir': out_dir, + 'report_dir': report, + 'in_module': in_module + } + if os.path.isfile(in_files): + files_config['root_path'] = os.path.dirname(in_files) + files_config['in_files'] = [in_files] + else: + for root_dir, _, files in os.walk(in_files): + for file in files: + files_config['in_files'].append(os.path.join(root_dir, file)) + main(files_config) diff --git a/mindinsight/mindconverter/common/log.py b/mindinsight/mindconverter/common/log.py index 925946985262e66228003e32a7cb7cbd525c5100..e66c382bd6ec5b229076200b19d9bac26aecb8f9 100644 --- a/mindinsight/mindconverter/common/log.py +++ b/mindinsight/mindconverter/common/log.py @@ -15,4 +15,4 @@ """Create a logger.""" from mindinsight.utils.log import setup_logger -logger = setup_logger("mindconverter", "mindconverter") +logger = setup_logger("mindconverter", "mindconverter", console=False) diff --git a/mindinsight/mindconverter/config.py b/mindinsight/mindconverter/config.py index 5f000f0159e61f297c7f766a7b759f81791c196f..3ebc44db87bd0174d5e812c62e08dedd65a988f6 100644 --- a/mindinsight/mindconverter/config.py +++ b/mindinsight/mindconverter/config.py @@ -67,7 +67,7 @@ class APIPt: Raises: ValueError: If can not use ast to parse or the required parse node not type of ast.Call, - or the given args_str not valid. + or the given args_str not valid. """ # expr is REQUIRED to meet (**) format if not (len(args_str) >= 2 and args_str[0] == "(" and args_str[-1] == ")"): @@ -222,6 +222,7 @@ class MappingHelper: def convert(self, call_name_pt: str, args_str_pt: str): """ Convert code sentence to MindSpore code sentence. + Args: call_name_pt (str): str of the call function, etc. args_str_pt (str): str of args for function, which starts with '(' and end with ')'. @@ -336,11 +337,12 @@ def gen_explicit_map_one_delta(params_pt, args_pt, k_ms, k_pt): def load_json_file(file_path): """ Load data from given json file path. 
+ Args: file_path (str): The file to load json data from. Returns: - list, the list data stored in file_path. + list(str), the list data stored in file_path. """ with open(file_path, 'r', encoding='utf-8') as file: info = json.loads(file.read()) diff --git a/mindinsight/mindconverter/converter.py b/mindinsight/mindconverter/converter.py index ddcd128a8acbad4e9e3fc8335f55cde32484f8df..0addae6cd59cc4c64c295072d1952bc0be0bf6ef 100644 --- a/mindinsight/mindconverter/converter.py +++ b/mindinsight/mindconverter/converter.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -"""main module""" -import inspect +"""converter module""" import copy import importlib +import inspect import os import stat @@ -24,332 +24,438 @@ from mindinsight.mindconverter.config import NN_LIST from mindinsight.mindconverter.config import ALL_TORCH_APIS from mindinsight.mindconverter.config import ALL_2P_LIST from mindinsight.mindconverter.config import UNSUPPORTED_WARN_INFOS +from mindinsight.mindconverter.config import ALL_UNSUPPORTED from mindinsight.mindconverter.common.log import logger +from mindinsight.mindconverter.forward_call import ForwardCall -def is_local_defined(obj, member): - """ - Check if obj and member are both defined in the same source file. - Args: - obj (Union[object, module]): A module or a class. - member (func): A function of obj. - - Returns: - bool, True or False. - """ - srcfile = inspect.getsourcefile(obj) - return inspect.getsourcefile(member) == srcfile - - -def is_valid_module(obj, member): - """ - Check if obj and member defined in same source file and member is inherited from torch.nn.Module. - Args: - obj (Union[object, module]): A module or a class. - member (func): A function. - - Returns: - bool, True or False. 
- """ - return inspect.isclass(member) and (member.__base__.__name__ == 'Module') and is_local_defined(obj, member) - - -def is_valid_function(obj, member): - """ - Check if member is function and defined in the file same as obj. - Args: - obj (Union[object, module]: The obj. - member (func): The func. - - Returns: - bool, True or False. - """ - return inspect.isfunction(member) and is_local_defined(obj, member) - - -def find_left_parentheses(string, right): - """ - Find index of the first left parenthesis. - Args: - string (str): A line of code. - right (int): Max index of string, same as `len(string) -1`. - - Returns: - int, index of the first parenthesis. - - Raises: - ValueError: If line of code doesn't contain any pair of `()` or `(` and `)` are not paired. - """ - if string[right] != ')': - raise ValueError('code [{}] at index {} not ")".'.format(string, right)) - stack = [] - for i in range(right, -1, -1): - if string[i] == ')': - stack.append(')') - elif string[i] == '(': - stack.pop() - if not stack: - return i - raise ValueError("{} should contain ()".format(string)) - - -def find_right_parentheses(string, left): - """ - Find first index of right parenthesis which make all left parenthesis make sense. - Args: - string (str): A line of code. - left (int): Start index of string to find from. - - Returns: - int, index of the found right parenthesis. - - Raises: - ValueError: If line of code doesn't contain any pair of `()` or `(` and `)` are not paired. 
- """ - stack = [] - for i in range(left, len(string)): - if string[i] == '(': - stack.append('(') - elif string[i] == ')': - stack.pop() - if not stack: - return i - raise ValueError("{} should contain ()".format(string)) - - -def get_call_name(code, end): - """ - Traverse code in a reversed function from index end and get the call name and start index of the call name, if call - name not found, return a null character string and -1 +class Converter: + """Convert class""" - Args: - code (str): The str of code to find from. - end (int): Start index to find. + convert_info = '' + flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL + modes = stat.S_IWUSR | stat.S_IRUSR - Returns: - str, founded api name if found, else a null character string. - int, start index of founded api name, -1 if api name not found - """ - stack = [] - for i in range(end - 1, -1, -1): - if code[i] in ["(", "[", "{"]: - if stack: + @staticmethod + def is_local_defined(obj, member): + """ + Check if obj and member are both defined in the same source file. + + Args: + obj (Union[object, module]): A module or a class. + member (func): A function of obj. + + Returns: + bool, True or False. + """ + srcfile = inspect.getsourcefile(obj) + return inspect.getsourcefile(member) == srcfile + + @classmethod + def is_valid_module(cls, obj, member): + """ + Check if obj and member defined in same source file and member is inherited from torch.nn.Module. + + Args: + obj (Union[object, module]): A module or a class. + member (func): A function. + + Returns: + bool, True or False. + """ + if inspect.isclass(member): + is_subclass = member.__base__.__name__ in ['Module', + 'Sequential', + 'ModuleList', + 'ModuleDict', + 'ParameterList', + 'ParameterDict'] + return is_subclass and cls.is_local_defined(obj, member) + return False + + @classmethod + def is_valid_function(cls, obj, member): + """ + Check if member is function and defined in the file same as obj. + + Args: + obj (Union[object, module]: The obj. 
+ member (func): The func. + + Returns: + bool, True or False. + """ + return inspect.isfunction(member) and cls.is_local_defined(obj, member) + + @staticmethod + def find_left_parentheses(string, right): + """ + Find index of the first left parenthesis. + + Args: + string (str): A line of code. + right (int): The right index for string to find from. + + Returns: + int, index of the first parenthesis. + + Raises: + ValueError: If line of code doesn't contain any pair of `()` or `(` and `)` are not paired. + """ + if string[right] != ')': + raise ValueError('code [{}] at index {} not ")".'.format(string, right)) + stack = [] + for i in range(right, -1, -1): + if string[i] == ')': + stack.append(')') + elif string[i] == '(': + stack.pop() + if not stack: + return i + raise ValueError("{} should contain ()".format(string)) + + @staticmethod + def find_right_parentheses(string, left): + """ + Find first index of right parenthesis which make all left parenthesis make sense. + + Args: + string (str): A line of code. + left (int): Start index of string to find from. + + Returns: + int, index of the found right parenthesis. + + Raises: + ValueError: If line of code doesn't contain any pair of `()` or `(` and `)` are not paired. + """ + stack = [] + for i in range(left, len(string)): + if string[i] == '(': + stack.append('(') + elif string[i] == ')': stack.pop() - else: + if not stack: + return i + raise ValueError("{} should contain ()".format(string)) + + @staticmethod + def get_call_name(code, end): + """ + Traverse code in a reversed function from index end and get the call name and start index of the call name, + if call name not found, return a null character string and -1 + + Args: + code (str): The str of code to find from. + end (int): Start index to find. 
+ + Returns: + tuple(str, int), one is founded api name if found, else a null character string, the other is start index + of founded api name, -1 if api name not found + """ + stack = [] + for i in range(end - 1, -1, -1): + if code[i] in ["(", "[", "{"]: + if stack: + stack.pop() + else: + return code[i + 1:end], i + 1 + elif code[i] in [")", "]", "}"]: + stack.append(code[i]) + elif stack: + continue + elif not (code[i].isalpha() or code[i].isdigit() or code[i] == '_' or code[i] == '.'): return code[i + 1:end], i + 1 - elif code[i] in [")", "]", "}"]: - stack.append(code[i]) - elif stack: - continue - elif not (code[i].isalpha() or code[i].isdigit() or code[i] == '_' or code[i] == '.'): - return code[i + 1:end], i + 1 - return "", -1 - - -def convert_api(code, start, api_name=""): - """ - Convert api_name in code to MindSpore api with start as a start index, if api_name is a python api, code will not - convert. - - Args: - code (str): The str code to convert. - start (int): The index of code to start convert from. - api_name (str): The api name to convert. - - Returns: - str, the converted code. - int, index of converted api_name in code. - - """ - # handle format like .shape( - if api_name.startswith('.'): - call_name, new_start = get_call_name(code, start) - if start == -1 or call_name == "self": - return code, start + 1 - else: - call_name = api_name - new_start = start - - # find full api expected to be converted. 
eg:expr="nn.Conv2d(1,2,3)" args_str="(1,2,3)" - left = code.find("(", start) - if left == -1: - raise ValueError('"(" not found, {} should work with "("'.format(call_name)) - right = find_right_parentheses(code, left) - end = right - expr = code[start:end + 1] - args_str = code[left:right + 1] - - map_helper = ALL_MAPPING[api_name] - new_expr = map_helper.convert(call_name, args_str) - next_newline = code.find("\n", end + 1) - fill_num = (expr.count("\n") - new_expr.count("\n")) - if next_newline != -1: - code = code[:new_start] + new_expr + code[end + 1:next_newline] + ("\n" * fill_num) + code[next_newline:] - else: - code = code[:new_start] + new_expr + ")" + ("\n" * fill_num) + code[end + 2:] - return code, start + len(map_helper.ms_api.name) - - -def find_api(code, i, is_forward): + return "", -1 + + def convert_api(self, code, start, api_name=""): + """ + Convert api_name in code to MindSpore api with start as a start index, if api_name is a python api, + code will not convert. + + Args: + code (str): The str code to convert. + start (int): The index of code to start convert from. + api_name (str): The api name to convert. + + Returns: + str, the converted code. + int, index of converted api_name in code. + """ + # handle format like .shape( + if api_name.startswith('.'): + call_name, new_start = self.get_call_name(code, start) + if start == -1 or call_name == "self": + return code, start + 1 + else: + call_name = api_name + new_start = start + + # find full api expected to be converted. 
eg:expr="nn.Conv2d(1,2,3)" args_str="(1,2,3)" + left = code.find("(", start) + if left == -1: + raise ValueError('"(" not found, {} should work with "("'.format(call_name)) + right = self.find_right_parentheses(code, left) + end = right + expr = code[start:end + 1] + args_str = code[left:right + 1] + + map_helper = ALL_MAPPING[api_name] + new_expr = map_helper.convert(call_name, args_str) + next_newline = code.find("\n", end + 1) + fill_num = (expr.count("\n") - new_expr.count("\n")) + if next_newline != -1: + code = code[:new_start] + new_expr + code[end + 1:next_newline] + ("\n" * fill_num) + code[next_newline:] + else: + code = code[:new_start] + new_expr + ")" + ("\n" * fill_num) + code[end + 2:] + + return code, start + len(map_helper.ms_api.name) + + @staticmethod + def find_api(code, i, is_forward): + """ + Find api name from code with a start index i, check api name ok with a is_forward condition. + + Args: + code (str): The code from which to find api name. + i (int): The start index to find. + is_forward (bool): Check if the found api name ok. + + Returns: + str, api name if find api name and check ok with is_forward condition, else a null character string. + """ + if code[i:].startswith("nn.") \ + or code[i:].startswith("F.") \ + or code[i:].startswith("torch.") \ + or code[i:].startswith('.'): + j = code.find('(', i) + if j != -1 and code[i:j] in ALL_TORCH_APIS: + api_name = code[i:j] + if (not is_forward and api_name in NN_LIST) or (is_forward and api_name in ALL_2P_LIST): + return api_name + return "" + + def convert_function(self, fun_name, fun, is_forward): + """ + Convert a PyTorch function into MindSpore function. + + Args: + fun_name (str): The str of function name. + fun (func): The function to convert. + is_forward (bool): If the function is defined in forward function in nn.Module in torch. + + Returns: + dict, old code and converted code map if convert happens, else {}. 
+ """ + _, line_no = inspect.getsourcelines(fun) + logger.info("Line %3d: start converting function %s()", line_no, fun_name) + + code = inspect.getsource(fun) + code_saved = copy.copy(code) + + i = 0 + while i < len(code): + api_name = self.find_api(code, i, is_forward) + if api_name: + line_no1 = line_no + code[:i].count('\n') + if api_name in ALL_MAPPING: + logger.info("Line %3d start converting API: %s", line_no1, api_name) + code, i = self.convert_api(code, i, api_name) + self.convert_info += "[Convert][Line{:3d}] {} is converted.\n".format(line_no1, api_name) + continue + if api_name in ALL_UNSUPPORTED: + warn_info = ". " + UNSUPPORTED_WARN_INFOS[api_name] if api_name in UNSUPPORTED_WARN_INFOS else "" + logger.warning("Line %3d: found unsupported API: %s%s", line_no1, api_name, warn_info) + self.convert_info += "[Unconvert][Line{:3d}] {} didn't convert{}\n".format(line_no1, + api_name, warn_info) + i += 1 + return {code_saved: code} if code_saved != code else {} + + @staticmethod + def judge_forward(name, forward_list): + """ + Check if function is a forward function. + + Args: + name (str): The function name. + forward_list (set): A set of forward function. + + Returns: + bool, True or False + """ + is_forward = name in forward_list or name.split(".")[-1] == "forward" + if is_forward: + logger.debug("%s is a forward function", name) + return is_forward + + def convert_module(self, module_name, module, forward_list): + """ + Convert a PyTorch module code into MindSpore module code. + + Args: + module_name (str): The module's name. + module (module): The module to convert. + forward_list (set): A set of forward function. + + Returns: + dict, map of old code and converted code. 
+ """ + _, line_no = inspect.getsourcelines(module) + logger.info("Line {:3d}: start converting nn.Module {}".format(line_no, module_name)) + + mapped = {} + for name, member in inspect.getmembers(module): + if self.is_valid_function(module, member): + is_forward = self.judge_forward("{}.{}".format(module_name, name), forward_list) + mapped.update(self.convert_function(name, member, is_forward)) + return mapped + + def get_mapping(self, import_mod, forward_list): + """ + Convert code of a module and get mapping of old code and convert code. + + Args: + import_mod (module): The module to convert. + forward_list (set): A set of forward function. + + Returns: + dict, mapping for old code and converted code of the module + """ + mapping = {} + tasks = [] + for name, member in inspect.getmembers(import_mod): + if self.is_valid_module(import_mod, member): + _, line_no = inspect.getsourcelines(member) + tasks.append((line_no, self.convert_module, (name, member, forward_list))) + elif self.is_valid_function(import_mod, member): + _, line_no = inspect.getsourcelines(member) + is_forward = self.judge_forward("{}.{}".format(import_mod, name), forward_list) + tasks.append((line_no, self.convert_function, (name, member, is_forward))) + tasks.sort() + for _, convert_fun, args in tasks: + mapping.update(convert_fun(*args)) + return mapping + + def convert(self, import_name, output_dir, report_dir): + """ + Convert a module's code, code converted will be save in output_dir, and a report will be save in report_dir. + + Args: + import_name (str): The module from which to import the module to convert. + output_dir (str): The path to save converted file. + report_dir (str): The path to save report file. 
+ """ + logger.info("Start converting %s", import_name) + self.convert_info += '[Start Convert]\nThe module is {}\n'.format(import_name) + + import_mod = importlib.import_module(import_name) + + srcfile = inspect.getsourcefile(import_mod) + logger.info("Script file is %s", srcfile) + + forward_list = set(ForwardCall(srcfile).calls) + logger.debug("Forward_list: %s", forward_list) + + # replace python function under nn.Module + mapping = self.get_mapping(import_mod, forward_list) + + code = inspect.getsource(import_mod) + for key, value in mapping.items(): + code = code.replace(key, value) + + code = 'import mindspore.ops.operations as P\n' + code + code = 'import mindspore.nn as nn\n' + code + code = 'import mindspore\n' + code + + self.convert_info += '||[Import Add] Add follow import sentences:\n' + self.convert_info += 'import mindspore.ops.operations as P\n' + self.convert_info += 'import mindspore.nn as nn\n' + self.convert_info += 'import mindspore\n\n' + + code = code.replace('import torch', '# import torch') + code = code.replace('from torch', '# from torch') + code = code.replace('(nn.Module):', '(nn.Cell):') + code = code.replace('forward(', 'construct(') + code = code.replace('nn.Linear', 'nn.Dense') + code = code.replace('(nn.Sequential)', '(nn.SequentialCell)') + code = code.replace('nn.init.', 'pass # nn.init.') + + self.convert_info += '||[Import Annotated] Annotated follow import sentences:\n' + self.convert_info += 'import sentence on torch as follows are annotated:\n' + self.convert_info += 'import torch\n' + self.convert_info += 'from torch ...\n' + + self.convert_info += '||[Explicit Convert] Module or function are explicitly converted as follows:\n' + self.convert_info += '[nn.Module] is converted to [nn.Cell]\n' + self.convert_info += '[forward] is converted to [construct]\n' + self.convert_info += '[nn.Linear] is converted to [nn.Dense]\n' + self.convert_info += '[nn.Sequential] is converted to [nn.SequentialCell]\n' + self.convert_info += 
'[nn.init] is not converted and annotated\n' + self.convert_info += '[Convert over]' + + dest_file = os.path.join(output_dir, os.path.basename(srcfile)) + with os.fdopen(os.open(dest_file, self.flags, self.modes), 'w') as file: + file.write(code) + logger.info("Convert success. Result is wrote to %s.", dest_file) + + dest_report_file = os.path.join(report_dir, + '_'.join(os.path.basename(srcfile).split('.')[:-1]) + '_report.txt') + with os.fdopen(os.open(dest_report_file, self.flags, self.modes), 'a') as file: + file.write(self.convert_info) + logger.info("Convert report is saved in %s", dest_report_file) + + +def _get_name_ext(file): """ - Find api name from code with a start index i, check api name ok with a is_forward condition. - Args: - code (str): The code from which to find api name. - i (int): The start index to find. - is_forward (bool): Check if the found api name ok. + Split a file name in name and extension. - Returns: - str, api name if find api name and check ok with is_forward condition, else a null character string. - """ - if code[i:].startswith("nn.") \ - or code[i:].startswith("F.") \ - or code[i:].startswith("torch.") \ - or code[i:].startswith('.'): - j = code.find('(', i) - if j != -1 and code[i:j] in ALL_TORCH_APIS: - api_name = code[i:j] - if (not is_forward and api_name in NN_LIST) or (is_forward and api_name in ALL_2P_LIST): - return api_name - return "" - - -def convert_function(fun_name, fun, is_forward): - """ - Convert a PyTorch function into MindSpore function. Args: - fun_name (str): The str of function name. - fun (func): The function to convert. - is_forward (bool): If the function is defined in forward function in nn.Module in torch. + file (str): Full file path. Returns: - dict, old code and converted code map if convert happens, else {}. + tuple (str, str), name and extension. 
""" - _, line_no = inspect.getsourcelines(fun) - logger.info("Line %3d: start converting function %s()", line_no, fun_name) - - code = inspect.getsource(fun) - code_saved = copy.copy(code) - - i = 0 - while i < len(code): - api_name = find_api(code, i, is_forward) - if api_name: - line_no1 = line_no + code[:i].count('\n') - if api_name in ALL_MAPPING: - logger.info("Line %3d start converting API: %s", line_no1, api_name) - code, i = convert_api(code, i, api_name) - continue - warn_info = ". " + UNSUPPORTED_WARN_INFOS[api_name] if api_name in UNSUPPORTED_WARN_INFOS else "" - logger.warning("Line %3d: found unsupported API: %s%s", line_no1, api_name, warn_info) - i += 1 - return {code_saved: code} if code_saved != code else {} + _, name = os.path.split(file) + return os.path.splitext(name) -def judge_forward(name, forward_list): +def _path_split(file): """ - Check if function is a forward function. + Split a path in head and tail. Args: - name (str): The function name. - forward_list (set): A set of forward function. + file (str): The file path. Returns: - bool, True or False - """ - is_forward = name in forward_list or name.split(".")[-1] == "forward" - if is_forward: - logger.debug("%s is a forward function", name) - return is_forward + list[str], list of file tail - -def convert_module(module_name, module, forward_list): """ - Convert a PyTorch module code into MindSpore module code. - Args: - module_name (str): The module's name. - module (module): The module to convert. - forward_list (set): A set of forward function. + file_dir, name = os.path.split(file) + if file_dir: + sep = file[len(file_dir)-1] + if file_dir.startswith(sep): + return file.split(sep)[1:] - Returns: - dict, map of old code and converted code. 
- """ - _, line_no = inspect.getsourcelines(module) - logger.info("Line {:3d}: start converting nn.Module {}".format(line_no, module_name)) - - mapped = {} - for name, member in inspect.getmembers(module): - if is_valid_function(module, member): - is_forward = judge_forward("{}.{}".format(module_name, name), forward_list) - mapped.update(convert_function(name, member, is_forward)) - return mapped + return file.split(sep) + return [name] -def get_mapping(import_mod, forward_list): +def main(files_config): """ - Convert code of a module and get mapping of old code and convert code. - Args: - import_mod (module): The module to convert. - forward_list (set): A set of forward function. + The entrance for converter, script files will be converted. - Returns: - dict, mapping for old code and converted code of the module - """ - mapping = {} - tasks = [] - for name, member in inspect.getmembers(import_mod): - if is_valid_module(import_mod, member): - _, line_no = inspect.getsourcelines(member) - tasks.append((line_no, convert_module, (name, member, forward_list))) - elif is_valid_function(import_mod, member): - _, line_no = inspect.getsourcelines(member) - is_forward = judge_forward("{}.{}".format(import_mod, name), forward_list) - tasks.append((line_no, convert_function, (name, member, is_forward))) - tasks.sort() - for _, convert_fun, args in tasks: - mapping.update(convert_fun(*args)) - return mapping - - -def convert(import_name, nn_module): - """ - The entrance for convert a module's code, code converted will be write to file called out.py. Args: - import_name (str): The module from which to import the module to convert. - nn_module (str): Name of the module to convert. - + files_config (dict): The config of files which to convert. 
""" - logger.info("Start converting %s.%s", import_name, nn_module) - import_mod = importlib.import_module(import_name) - - forward_list = set() - - logger.debug("Forward_list: %s", forward_list) - - # replace python function under nn.Modlue - mapping = get_mapping(import_mod, forward_list) - - code = inspect.getsource(import_mod) - for key, value in mapping.items(): - code = code.replace(key, value) - - code = 'import mindspore.ops.operations as P\n' + code - code = 'import mindspore.nn as nn\n' + code - code = 'import mindspore\n' + code - code = code.replace('import torch', '# import torch') - code = code.replace('from torch', '# from torch') - code = code.replace('(nn.Module):', '(nn.Cell):') - code = code.replace('forward(', 'construct(') - code = code.replace('nn.Linear', 'nn.Dense') - code = code.replace('(nn.Sequential)', '(nn.SequentialCell)') - code = code.replace('nn.init.', 'pass # nn.init.') - - flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL - modes = stat.S_IWUSR | stat.S_IRUSR - with os.fdopen(os.open('out.py', flags, modes), 'w') as file: - file.write(code) - logger.info("Convert success. 
Result is wrote to out.py\n") - - -if __name__ == '__main__': - - convert('torchvision.models.resnet', 'resnet18') + convert_ins = Converter() + root_path = files_config['root_path'] + in_files = files_config['in_files'] + for in_file in in_files: + in_file_split = _path_split(in_file[len(root_path):]) + in_file_split[-1], _ = _get_name_ext(in_file_split[-1]) + module_name = '.'.join(in_file_split) + convert_ins.convert(module_name, files_config['outfile_dir'], files_config['report_dir']) + + in_module = files_config['in_module'] + if in_module: + convert_ins.convert(in_module, files_config['outfile_dir'], files_config['report_dir']) diff --git a/mindinsight/mindconverter/forward_call.py b/mindinsight/mindconverter/forward_call.py new file mode 100644 index 0000000000000000000000000000000000000000..61e401331b6b8faf02c2befd98a4114ab4dec9cc --- /dev/null +++ b/mindinsight/mindconverter/forward_call.py @@ -0,0 +1,96 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Find out forward functions of script file""" +import ast +import os + + +class ForwardCall(ast.NodeVisitor): + """ + AST visitor that processes forward calls. + + Find the sub functions called by the forward function in the script file. 
+ """ + + def __init__(self, filename): + self.filename = filename + self.module_name = os.path.basename(filename).replace('.py', '') + self.name_stack = [] + self.forward_stack = [] + self.calls = [] + self.process() + + def process(self): + """Parse the python source file to find the forward functions.""" + with open(self.filename, 'rt', encoding='utf-8') as file: + content = file.read() + self.visit(ast.parse(content, self.filename)) + + def get_current_namespace(self): + """Get the namespace when visit the AST node""" + namespace = '.'.join(self.name_stack) + return namespace + + @classmethod + def get_ast_node_name(cls, node): + """Get AST node name.""" + if isinstance(node, ast.Attribute): + return f'{cls.get_ast_node_name(node.value)}.{node.attr}' + + if isinstance(node, ast.Name): + return node.id + + return node + + def visit_ClassDef(self, node): + """Callback function when visit AST tree""" + self.name_stack.append(node.name) + self.generic_visit(node) + self.name_stack.pop() + + def visit_FunctionDef(self, node): + """Callback function when visit AST tree""" + func_name = f'{self.get_current_namespace()}.{node.name}' + is_in_chain = func_name in self.calls or node.name == 'forward' + if is_in_chain: + self.forward_stack.append(func_name) + + if node.name == 'forward': + self.calls.append(func_name) + + self.generic_visit(node) + + if is_in_chain: + self.forward_stack.pop() + + def visit_Call(self, node): + """Callback function when visit AST tree""" + for arg in node.args: + self.visit(arg) + for kw in node.keywords: + self.visit(kw.value) + func_name = self.get_ast_node_name(node.func) + if isinstance(node.func, ast.Name): + if func_name not in ['super', 'str', 'repr']: + if self.forward_stack: + self.calls.append(func_name) + self.visit(node.func) + else: + if self.forward_stack: + if 'self' in func_name: + self.calls.append(f'{self.get_current_namespace()}.{func_name.split(".")[-1]}') + else: + self.calls.append(func_name) + self.visit(node.func) 
diff --git a/setup.py b/setup.py index c54dad2b8325f0f38cf046caed3ff502b30a0491..b83b8e45c55a43068e4b7bfed95219c111dd0843 100644 --- a/setup.py +++ b/setup.py @@ -194,6 +194,7 @@ if __name__ == '__main__': entry_points={ 'console_scripts': [ 'mindinsight=mindinsight.utils.command:main', + 'mindconverter=mindinsight.mindconverter.cli:cli_entry', ], }, python_requires='>=3.7',