提交 3d7ccce8 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!146 mindconverter add cli and report

Merge pull request !146 from quyongxiu1/br_0522_dev_wzk_qyz
### Introduction
MindConverter is a tool that converting PyTorch scripts to MindSpore scripts. With minial manual editing and the guidance from conversion reports, users may easily migrate their model from PyTorch framework to MindSpore.
### System Requirements
* PyTorch v1.5.0
* MindSpore v0.2.0
### Installation
This tool is part of MindInsight and accessible to users after installing MindInsight, no extra installation is needed.
### Commandline Usage
Set the model scripts directory as the PYTHONPATH environment variable first:
```buildoutcfg
export PYTHONPATH=<model scripts dir>
```
mindconverter commandline usage:
```buildoutcfg
mindconverter [-h] [--version] --in_file IN_FILE [--output OUTPUT]
[--report REPORT]
MindConverter CLI entry point (version: 0.2.0)
optional arguments:
-h, --help show this help message and exit
--version show program's version number and exit
--in_file IN_FILE Specify path for script file.
--output OUTPUT Specify path for converted script file directory. Default
is output directory in the current working directory.
--report REPORT Specify report directory. Default is the current working
directorys
```
Usage example:
```buildoutcfg
export PYTHONPATH=~/my_pt_proj/models
mindconverter --in_file lenet.py
```
Since the conversion is not 100% flawless, we encourage users to checkout the reports when fixing issues of the converted scripts.
### Unsupported Situation #1
Classes and functions that can't be converted:
* The use of shape, ndim and dtype member of torch.Tensor.
* torch.nn.AdaptiveXXXPoolXd and torch.nn.functional.adaptive_XXX_poolXd()
* torch.nn.functional.Dropout
* torch.unsqueeze() and torch.Tensor.unsqueeze()
* torch.chunk() and torch.Tensor.chunk()
### Unsupported Situation #2
Subclassing from the subclasses of nn.Module
e.g. (code snip from torchvision,models.mobilenet)
```python
from torch import nn
class ConvBNReLU(nn.Sequential):
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
padding = (kernel_size - 1) // 2
super(ConvBNReLU, self).__init__(
nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
nn.BatchNorm2d(out_planes),
nn.ReLU6(inplace=True)
)
```
### Unsupported Situation #3
Unconventional import naming
e.g.
```python
import torch.nn as mm
```
\ No newline at end of file
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Command module."""
import os
import sys
import argparse
import mindinsight
from mindinsight.mindconverter.converter import main
class FileDirAction(argparse.Action):
"""File directory action class definition."""
@staticmethod
def check_path(parser, values, option_string=None):
"""
Check argument for file path.
Args:
parser (ArgumentParser): Passed-in argument parser.
values (object): Argument values with type depending on argument definition.
option_string (str): Optional string for specific argument name. Default: None.
"""
outfile = values
if outfile.startswith('~'):
outfile = os.path.realpath(os.path.expanduser(outfile))
if not outfile.startswith('/'):
outfile = os.path.realpath(os.path.join(os.getcwd(), outfile))
if os.path.exists(outfile) and not os.access(outfile, os.R_OK):
parser.error(f'{option_string} {outfile} not accessible')
return outfile
def __call__(self, parser, namespace, values, option_string=None):
"""
Inherited __call__ method from argparse.Action.
Args:
parser (ArgumentParser): Passed-in argument parser.
namespace (Namespace): Namespace object to hold arguments.
values (object): Argument values with type depending on argument definition.
option_string (str): Optional string for specific argument name. Default: None.
"""
outfile_dir = self.check_path(parser, values, option_string)
if os.path.isfile(outfile_dir):
parser.error(f'{option_string} {outfile_dir} is a file')
setattr(namespace, self.dest, outfile_dir)
class OutputDirAction(argparse.Action):
"""File directory action class definition."""
def __call__(self, parser, namespace, values, option_string=None):
"""
Inherited __call__ method from argparse.Action.
Args:
parser (ArgumentParser): Passed-in argument parser.
namespace (Namespace): Namespace object to hold arguments.
values (object): Argument values with type depending on argument definition.
option_string (str): Optional string for specific argument name. Default: None.
"""
output = values
if output.startswith('~'):
output = os.path.realpath(os.path.expanduser(output))
if not output.startswith('/'):
output = os.path.realpath(os.path.join(os.getcwd(), output))
if os.path.exists(output):
if not os.access(output, os.R_OK):
parser.error(f'{option_string} {output} not accessible')
if os.path.isfile(output):
parser.error(f'{option_string} {output} is a file')
setattr(namespace, self.dest, output)
class InFileAction(argparse.Action):
"""Input File action class definition."""
def __call__(self, parser, namespace, values, option_string=None):
"""
Inherited __call__ method from argparse.Action.
Args:
parser (ArgumentParser): Passed-in argument parser.
namespace (Namespace): Namespace object to hold arguments.
values (object): Argument values with type depending on argument definition.
option_string (str): Optional string for specific argument name. Default: None.
"""
outfile_dir = FileDirAction.check_path(parser, values, option_string)
if not os.path.exists(outfile_dir):
parser.error(f'{option_string} {outfile_dir} not exists')
if not os.path.isfile(outfile_dir):
parser.error(f'{option_string} {outfile_dir} is not a file')
setattr(namespace, self.dest, outfile_dir)
class LogFileAction(argparse.Action):
"""Log file action class definition."""
def __call__(self, parser, namespace, values, option_string=None):
"""
Inherited __call__ method from FileDirAction.
Args:
parser (ArgumentParser): Passed-in argument parser.
namespace (Namespace): Namespace object to hold arguments.
values (object): Argument values with type depending on argument definition.
option_string (str): Optional string for specific argument name. Default: None.
"""
outfile_dir = FileDirAction.check_path(parser, values, option_string)
if os.path.exists(outfile_dir) and not os.path.isdir(outfile_dir):
parser.error(f'{option_string} {outfile_dir} is not a directory')
setattr(namespace, self.dest, outfile_dir)
def cli_entry():
"""Entry point for mindconverter CLI."""
permissions = os.R_OK | os.W_OK | os.X_OK
os.umask(permissions << 3 | permissions)
parser = argparse.ArgumentParser(
prog='mindconverter',
description='MindConverter CLI entry point (version: {})'.format(mindinsight.__version__))
parser.add_argument(
'--version',
action='version',
version='%(prog)s ({})'.format(mindinsight.__version__))
parser.add_argument(
'--in_file',
type=str,
action=InFileAction,
required=True,
help="""
Specify path for script file.
""")
parser.add_argument(
'--output',
type=str,
action=OutputDirAction,
default=os.path.join(os.getcwd(), 'output'),
help="""
Specify path for converted script file directory.
Default is output directory in the current working directory.
""")
parser.add_argument(
'--report',
type=str,
action=LogFileAction,
default=os.getcwd(),
help="""
Specify report directory. Default is the current working directory.
""")
argv = sys.argv[1:]
if not argv:
argv = ['-h']
args = parser.parse_args(argv)
else:
args = parser.parse_args()
mode = permissions << 6
os.makedirs(args.output, mode=mode, exist_ok=True)
os.makedirs(args.report, mode=mode, exist_ok=True)
_run(args.in_file, args.output, '', args.report)
def _run(in_files, out_dir, in_module, report):
"""
Run converter command.
Args:
in_files (str): The file path or directory to convert.
out_dir (str): The output directory to save converted file.
in_module (str): The module name to convert.
report (str): The report file path.
"""
files_config = {
'root_path': in_files if in_files else '',
'in_files': [],
'outfile_dir': out_dir,
'report_dir': report,
'in_module': in_module
}
if os.path.isfile(in_files):
files_config['root_path'] = os.path.dirname(in_files)
files_config['in_files'] = [in_files]
else:
for root_dir, _, files in os.walk(in_files):
for file in files:
files_config['in_files'].append(os.path.join(root_dir, file))
main(files_config)
...@@ -15,4 +15,4 @@ ...@@ -15,4 +15,4 @@
"""Create a logger.""" """Create a logger."""
from mindinsight.utils.log import setup_logger from mindinsight.utils.log import setup_logger
logger = setup_logger("mindconverter", "mindconverter") logger = setup_logger("mindconverter", "mindconverter", console=False)
...@@ -67,7 +67,7 @@ class APIPt: ...@@ -67,7 +67,7 @@ class APIPt:
Raises: Raises:
ValueError: If can not use ast to parse or the required parse node not type of ast.Call, ValueError: If can not use ast to parse or the required parse node not type of ast.Call,
or the given args_str not valid. or the given args_str not valid.
""" """
# expr is REQUIRED to meet (**) format # expr is REQUIRED to meet (**) format
if not (len(args_str) >= 2 and args_str[0] == "(" and args_str[-1] == ")"): if not (len(args_str) >= 2 and args_str[0] == "(" and args_str[-1] == ")"):
...@@ -222,6 +222,7 @@ class MappingHelper: ...@@ -222,6 +222,7 @@ class MappingHelper:
def convert(self, call_name_pt: str, args_str_pt: str): def convert(self, call_name_pt: str, args_str_pt: str):
""" """
Convert code sentence to MindSpore code sentence. Convert code sentence to MindSpore code sentence.
Args: Args:
call_name_pt (str): str of the call function, etc. call_name_pt (str): str of the call function, etc.
args_str_pt (str): str of args for function, which starts with '(' and end with ')'. args_str_pt (str): str of args for function, which starts with '(' and end with ')'.
...@@ -336,11 +337,12 @@ def gen_explicit_map_one_delta(params_pt, args_pt, k_ms, k_pt): ...@@ -336,11 +337,12 @@ def gen_explicit_map_one_delta(params_pt, args_pt, k_ms, k_pt):
def load_json_file(file_path): def load_json_file(file_path):
""" """
Load data from given json file path. Load data from given json file path.
Args: Args:
file_path (str): The file to load json data from. file_path (str): The file to load json data from.
Returns: Returns:
list, the list data stored in file_path. list(str), the list data stored in file_path.
""" """
with open(file_path, 'r', encoding='utf-8') as file: with open(file_path, 'r', encoding='utf-8') as file:
info = json.loads(file.read()) info = json.loads(file.read())
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Find out forward functions of script file"""
import ast
import os
class ForwardCall(ast.NodeVisitor):
"""
AST visitor that processes forward calls.
Find the sub functions called by the forward function in the script file.
"""
def __init__(self, filename):
self.filename = filename
self.module_name = os.path.basename(filename).replace('.py', '')
self.name_stack = []
self.forward_stack = []
self.calls = []
self.process()
def process(self):
"""Parse the python source file to find the forward functions."""
with open(self.filename, 'rt', encoding='utf-8') as file:
content = file.read()
self.visit(ast.parse(content, self.filename))
def get_current_namespace(self):
"""Get the namespace when visit the AST node"""
namespace = '.'.join(self.name_stack)
return namespace
@classmethod
def get_ast_node_name(cls, node):
"""Get AST node name."""
if isinstance(node, ast.Attribute):
return f'{cls.get_ast_node_name(node.value)}.{node.attr}'
if isinstance(node, ast.Name):
return node.id
return node
def visit_ClassDef(self, node):
"""Callback function when visit AST tree"""
self.name_stack.append(node.name)
self.generic_visit(node)
self.name_stack.pop()
def visit_FunctionDef(self, node):
"""Callback function when visit AST tree"""
func_name = f'{self.get_current_namespace()}.{node.name}'
is_in_chain = func_name in self.calls or node.name == 'forward'
if is_in_chain:
self.forward_stack.append(func_name)
if node.name == 'forward':
self.calls.append(func_name)
self.generic_visit(node)
if is_in_chain:
self.forward_stack.pop()
def visit_Call(self, node):
"""Callback function when visit AST tree"""
for arg in node.args:
self.visit(arg)
for kw in node.keywords:
self.visit(kw.value)
func_name = self.get_ast_node_name(node.func)
if isinstance(node.func, ast.Name):
if func_name not in ['super', 'str', 'repr']:
if self.forward_stack:
self.calls.append(func_name)
self.visit(node.func)
else:
if self.forward_stack:
if 'self' in func_name:
self.calls.append(f'{self.get_current_namespace()}.{func_name.split(".")[-1]}')
else:
self.calls.append(func_name)
self.visit(node.func)
...@@ -194,6 +194,7 @@ if __name__ == '__main__': ...@@ -194,6 +194,7 @@ if __name__ == '__main__':
entry_points={ entry_points={
'console_scripts': [ 'console_scripts': [
'mindinsight=mindinsight.utils.command:main', 'mindinsight=mindinsight.utils.command:main',
'mindconverter=mindinsight.mindconverter.cli:cli_entry',
], ],
}, },
python_requires='>=3.7', python_requires='>=3.7',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册