提交 7cad801d 编写于 作者: G ggpolar

parse scripts by AST in converter module.

Use the AST replaces the importlib/inspect modules to analyze and modify network definition script.
The importlib/inspect must load python script to analyze, but AST analysis is static code analysis and is very secure.
上级 ce93cb82
此差异已折叠。
......@@ -186,25 +186,23 @@ def cli_entry():
mode = permissions << 6
os.makedirs(args.output, mode=mode, exist_ok=True)
os.makedirs(args.report, mode=mode, exist_ok=True)
_run(args.in_file, args.output, '', args.report)
_run(args.in_file, args.output, args.report)
def _run(in_files, out_dir, in_module, report):
def _run(in_files, out_dir, report):
"""
Run converter command.
Args:
in_files (str): The file path or directory to convert.
out_dir (str): The output directory to save converted file.
in_module (str): The module name to convert.
report (str): The report file path.
"""
files_config = {
'root_path': in_files if in_files else '',
'in_files': [],
'outfile_dir': out_dir,
'report_dir': report,
'in_module': in_module
'report_dir': report
}
if os.path.isfile(in_files):
files_config['root_path'] = os.path.dirname(in_files)
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless REQUIRED by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""code analysis module"""
import ast
import pasta
from pasta.base import scope
from mindinsight.mindconverter.common.exceptions import ScriptNotSupport
class APIAnalysisSpec:
"""API analysis specifications"""
import_name_mapping = {'torch': ['mindspore', None],
'torch.nn': ['mindspore.nn', 'nn'],
'torch.nn.functional': ['mindspore.ops.operations', 'P']}
base_name_mapping = {'Module': 'Cell',
'Sequential': 'SequentialCell'
}
@classmethod
def get_convertible_external_names(cls):
"""
Obtain the convertible external names.
The external name is the full dotted name being referenced.
"""
return cls.import_name_mapping.keys()
@staticmethod
def get_network_base_class_names():
"""Obtain the base names which network class base from"""
return ['Module',
'Sequential',
'ModuleList',
'ModuleDict',
'ParameterList',
'ParameterDict']
@staticmethod
def check_external_alias_ref(ref_name, external_name):
"""
Check 'import as' is standard.
Standard references are follow:
import torch.nn as nn
import torch.nn.functional as F
Args:
ref_name (str): The name that refers to the external_name.
external_name (str): The full dotted name being referenced. For examples:
1. 'import torch.nn as nn', torch.nn is external_name, nn is ref_name.
2. 'from torch import nn as mm, torch.nn is external_name, mm is ref_name which is not a standard name.
Returns:
boolean, True if ref_name is standard else False.
"""
if ref_name != 'nn' and external_name == 'torch.nn':
is_standard = False
elif ref_name != 'F' and external_name == 'torch.nn.functional':
is_standard = False
else:
is_standard = True
return is_standard
class CodeAnalyzer(ast.NodeVisitor):
"""Code analyzer that analyzes PyTorch python script by AST Visitor.
CodeAnalyzer find the codes that need to be converted to MindSpore,
and provides the attributes related to the codes.
"""
def __init__(self):
self._stack = [] # Used to easily access the parent node
self._external_references = {}
self._is_standard_external_ref = True
self._root_scope = None
# Used to save functions that need to be converted, value type is pasta.base.scope.Scope
self._network_functions = []
# Used to easily trace the function node
self._functions_stack = []
# key type is pasta.base.scope.Scope, value type is list
self._network_classes = {}
@property
def root_scope(self):
"""The root scope of the python script code."""
return self._root_scope
@property
def is_standard_external_ref(self):
"""Obtain whether the result is a standard external reference."""
return self._is_standard_external_ref
@property
def external_references(self):
"""Obtain all external references in the analyzed code."""
return self._external_references
def network_definitions(self):
"""Obtain the network definitions which need to be converted."""
return {"functions": self._network_functions,
"cell": self._network_classes}
def process(self, ast_tree):
"""
Start to analyze the code.
Args:
ast_tree (AST): The root node of the source code.
"""
self.__init__()
self._root_scope = scope.analyze(ast_tree)
self._pre_process()
self.visit(ast_tree)
if not self._network_classes:
msg = "model definition not be found."
raise ScriptNotSupport(msg)
@staticmethod
def _check_external_standard(external_refs):
"""Check whether all external references are standard."""
is_standard = True
for external_name, external_ref_info in external_refs.items():
is_standard = APIAnalysisSpec.check_external_alias_ref(external_name, external_ref_info.name)
if not is_standard:
break
return is_standard
def _is_base_from_cell(self, node):
"""
Check whether the node bases from cell classes which are defined in APIAnalysisSpec.
Args:
node (ast.ClassDef): The node which is a class definition.
Returns:
boolean, True if the check result is Passed else False.
"""
if self._is_ref_convertible_imports(node):
whole_name = self._get_whole_name(node)
if whole_name.split('.')[-1] in APIAnalysisSpec.get_network_base_class_names():
return True
return False
def _pre_process(self):
"""Preprocessor checks the code before analyzing."""
is_torch = False
# check whether the code imports torch.
for ref_name in self._root_scope.external_references.keys():
if ref_name.split('.')[0] in APIAnalysisSpec.get_convertible_external_names():
is_torch = True
break
if not is_torch:
msg = "The source code does not import torch, model definition can not be found."
raise ScriptNotSupport(msg)
# Find out external reference in the code and save it.
external_refs = self._analyze_import_references(self._root_scope)
self._is_standard_external_ref = self._check_external_standard(external_refs)
self._check_external_standard(external_refs)
for external_name, external_ref_info in external_refs.items():
self._external_references.update({
external_name: {
'external_ref_info': external_ref_info,
'parent_node': None
}
})
@staticmethod
def _analyze_import_references(root_scope):
"""Find out all references from the import statements."""
external_name_ref = {}
for node_references in root_scope.external_references.values():
for node_ref in node_references:
if node_ref.name_ref:
# (from)import alias, node_ref.name_ref.id is alias name
if node_ref.name_ref.definition.asname == node_ref.name_ref.id:
external_name_ref[node_ref.name_ref.id] = node_ref
# import without alias, node_ref.name_ref.definition.asname is None.
# e.g., import a.b.c, reference maybe is a, a.b or a.b.c in the root_scope.external_references.
# The reference a.b.c is really wanted.
elif node_ref.name_ref.definition.name == node_ref.name_ref.id:
external_name_ref[node_ref.name_ref.id] = node_ref
else:
pass
return external_name_ref
def visit(self, node):
"""Overridden visit of the base class to maintain stack information to access parent node."""
self._stack.append(node)
super(CodeAnalyzer, self).visit(node)
self._stack.pop()
@staticmethod
def _get_full_name(node):
"""Get the full name of the node."""
if not isinstance(node, (ast.Attribute, ast.Name)):
return None
return pasta.dump(node)
def _get_whole_name(self, node):
"""
Get the whole name of the node.
For example, nn.Module is spliced two nodes, nn node and Module node.
When visit ast nodes,
Module node is first visited, the full name is the same as the whole name, that is nn.Module.
And then nn node is visited, the full name is nn, the whole name is nn.Module.
"""
full_name = self._get_full_name(node)
if not full_name:
return None
# node is in stack top pos
if node is self._stack[-1]:
parent_index = -1
while isinstance(self._stack[parent_index], ast.Attribute):
parent_index -= 1
whole_name = self._get_full_name(self._stack[parent_index])
else:
whole_name = full_name
return whole_name
def _is_ref_convertible_imports(self, node):
"""Check whether the node references convertible imports."""
check_result = False
whole_name = self._get_whole_name(node)
if whole_name:
module_name = whole_name.split('.')[0]
for ref_name, ref_info in self._external_references.items():
external_ref = ref_info['external_ref_info']
# external reference is convertible module
if external_ref.name in APIAnalysisSpec.get_convertible_external_names():
# import from the same external module
if module_name == ref_name.split('.')[0]:
check_result = True
break
return check_result
@staticmethod
def _get_external_node(external_references):
"""Get all external reference nodes."""
external_nodes = {}
for ref_name, ref_info in external_references.items():
external_nodes.update({ref_info['external_ref_info'].node: ref_name})
return external_nodes
@staticmethod
def _get_convertible_external_node(external_name_ref):
"""Get all convertible external reference nodes."""
convertible_external_nodes = {}
for ref_name, ref_info in external_name_ref.items():
if ref_info['external_ref_info'].name in APIAnalysisSpec.get_convertible_external_names():
convertible_external_nodes.update({ref_info['external_ref_info'].node: ref_name})
return convertible_external_nodes
def _update_external_ref_parent(self, node):
"""Set external reference parent node info."""
external_nodes = self._get_external_node(self._external_references)
convertible_external_nodes = self._get_convertible_external_node(self._external_references)
for name_node in node.names:
if name_node in convertible_external_nodes.keys():
if len(node.names) > 1:
msg = """\
Not support multiple imports of torch on one line in your script. line:%s: %s
""" % (node.lineno, pasta.dump(node))
raise ScriptNotSupport(msg)
if name_node in external_nodes.keys():
ref_name = external_nodes[name_node]
self._external_references[ref_name]['parent_node'] = node
@staticmethod
def _get_class_scope(node_scope):
"""Find the class scope of the node_scope."""
parent_scope = node_scope.parent_scope
class_scope = None
while parent_scope:
if isinstance(parent_scope.node, ast.ClassDef):
class_scope = parent_scope
break
parent_scope = parent_scope.parent_scope
return class_scope
def _update_convertible_functions(self, node):
"""Update convertible functions."""
node_scope = self._root_scope.lookup_scope(node)
class_scope = self._get_class_scope(node_scope)
if class_scope:
network_classes = self._network_classes.get(class_scope, [])
if node_scope not in network_classes:
network_classes.append(node_scope)
else:
if node_scope not in self._network_functions:
self._network_functions.append(node_scope)
def visit_ClassDef(self, node):
"""Callback function when visit AST tree"""
if not self._stack[-1] is node:
return
for base in node.bases:
if self._is_ref_convertible_imports(base):
self._network_classes[self._root_scope.lookup_scope(node)] = []
self.generic_visit(node)
def visit_Import(self, node):
"""Callback function when visit AST tree"""
self._update_external_ref_parent(node)
self.generic_visit(node)
def visit_ImportFrom(self, node):
"""Callback function when visit AST tree"""
self._update_external_ref_parent(node)
self.generic_visit(node)
def visit_Call(self, node):
"""Callback function when visit AST tree"""
if not self._stack[-1] is node:
return
is_in_network_function = False
# If torch call is happened in the function, save the function for network definition.
if self._functions_stack and self._is_ref_convertible_imports(node.func):
self._update_convertible_functions(self._functions_stack[-1])
is_in_network_function = True
if not is_in_network_function:
self.generic_visit(node)
def visit_FunctionDef(self, node):
"""Callback function when visit AST tree"""
if not self._stack[-1] is node:
return
if node.name == "forward":
self._update_convertible_functions(node)
self._functions_stack.append(node)
self.generic_visit(node)
self._functions_stack.pop()
def get_name(self, node):
"""
Get the node name.
Args:
node (AST): The ast node of the source code.
Returns:
str, the name of the node
"""
if isinstance(node, pasta.base.scope.Scope):
items = [self.get_name(node.node)]
parent_scope = node.parent_scope
while parent_scope:
if not isinstance(parent_scope.node, ast.Module):
items.append(self.get_name(parent_scope.node))
parent_scope = parent_scope.parent_scope
return '.'.join(reversed(items))
if isinstance(node, (ast.ClassDef, ast.FunctionDef)):
return node.name
if isinstance(node, (ast.Name, ast.Attribute)):
return self._get_full_name(node)
return str(node)
def lookup_scope(self, node):
"""
Search the scope of the node.
Args:
node (AST): The ast node of the source code.
Returns:
scope, the scope of the node
"""
if isinstance(node, pasta.base.scope.Scope):
return node
return self._root_scope.lookup_scope(node)
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Define custom exception."""
from enum import unique
from mindinsight.utils.constant import ScriptConverterErrors
from mindinsight.utils.exceptions import MindInsightException
@unique
class ConverterErrors(ScriptConverterErrors):
"""Converter error codes."""
SCRIPT_NOT_SUPPORT = 1
NODE_TYPE_NOT_SUPPORT = 2
class ScriptNotSupport(MindInsightException):
"""The script can not support to process."""
def __init__(self, msg):
super(ScriptNotSupport, self).__init__(ConverterErrors.SCRIPT_NOT_SUPPORT,
msg,
http_code=400)
class NodeTypeNotSupport(MindInsightException):
"""The astNode can not support to process."""
def __init__(self, msg):
super(NodeTypeNotSupport, self).__init__(ConverterErrors.NODE_TYPE_NOT_SUPPORT,
msg,
http_code=400)
......@@ -14,7 +14,8 @@
# ============================================================================
"""Find out forward functions of script file"""
import ast
import os
import pasta
class ForwardCall(ast.NodeVisitor):
......@@ -24,73 +25,80 @@ class ForwardCall(ast.NodeVisitor):
Find the sub functions called by the forward function in the script file.
"""
def __init__(self, filename):
self.filename = filename
self.module_name = os.path.basename(filename).replace('.py', '')
self.name_stack = []
self.forward_stack = []
self.calls = set()
def __init__(self, ast_tree):
self._tree = ast_tree
self._name_stack = []
self._forward_stack = []
self.calls = {} # key is function name, value is forward function ast node.
self._function_list = {} # key is function name, value is function ast node.
self.process()
def process(self):
"""Parse the python source file to find the forward functions."""
with open(self.filename, 'rt', encoding='utf-8') as file:
content = file.read()
self.visit(ast.parse(content, self.filename))
"""visit ast tree to find the forward functions."""
self.visit(self._tree)
# first visit to find out all functions, so restores all variables except _function_list
self._name_stack.clear()
self._forward_stack.clear()
self.calls.clear()
self.visit(self._tree)
def get_current_namespace(self):
"""Get the namespace when visit the AST node"""
namespace = '.'.join(self.name_stack)
namespace = '.'.join(self._name_stack)
return namespace
@classmethod
def get_ast_node_name(cls, node):
"""Get AST node name."""
if isinstance(node, ast.Attribute):
return f'{cls.get_ast_node_name(node.value)}.{node.attr}'
if isinstance(node, ast.Name):
return node.id
def get_call_name(cls, node):
"""Get functional call name."""
if not isinstance(node, ast.Call):
return None
return node
return pasta.dump(node.func)
def visit_ClassDef(self, node):
"""Callback function when visit AST tree"""
self.name_stack.append(node.name)
self._name_stack.append(node.name)
self.generic_visit(node)
self.name_stack.pop()
self._name_stack.pop()
def visit_FunctionDef(self, node):
"""Callback function when visit AST tree"""
namespace = self.get_current_namespace()
if namespace:
func_name = f'{namespace}.{node.name}'
else:
func_name = node.name
func_name = f'{self.get_current_namespace()}.{node.name}'
is_in_chain = func_name in self.calls or node.name == 'forward'
if is_in_chain:
self.forward_stack.append(func_name)
self._forward_stack.append(func_name)
if node.name == 'forward':
self.calls.add(func_name)
self.calls.update({func_name: node})
self._function_list.update({func_name: node})
self.generic_visit(node)
if is_in_chain:
self.forward_stack.pop()
self._forward_stack.pop()
def visit_Call(self, node):
"""Callback function when visit AST tree"""
for arg in node.args:
self.visit(arg)
for kw in node.keywords:
self.visit(kw.value)
func_name = self.get_ast_node_name(node.func)
for keyword in node.keywords:
self.visit(keyword.value)
func_name = self.get_call_name(node)
if isinstance(node.func, ast.Name):
if func_name not in ['super', 'str', 'repr']:
if self.forward_stack:
self.calls.add(func_name)
if self._forward_stack:
self.calls.update({func_name: self._function_list.get(func_name)})
self.visit(node.func)
else:
if self.forward_stack:
if 'self' in func_name:
self.calls.add(f'{self.get_current_namespace()}.{func_name.split(".")[-1]}')
if self._forward_stack:
if func_name.startswith('self.'):
whole_name = f'{self.get_current_namespace()}.{func_name.split(".")[-1]}'
self.calls.update({whole_name: self._function_list.get(whole_name)})
else:
self.calls.add(func_name)
self.calls.update({func_name: self._function_list.get(func_name)})
self.visit(node.func)
......@@ -30,6 +30,7 @@ class MindInsightModules(Enum):
LINEAGEMGR = 2
DATAVISUAL = 5
PROFILERMGR = 6
SCRIPTCONVERTER = 7
class GeneralErrors(Enum):
......@@ -69,3 +70,7 @@ class DataVisualErrors(Enum):
SCALAR_NOT_EXIST = 14
HISTOGRAM_NOT_EXIST = 15
TRAIN_JOB_DETAIL_NOT_IN_CACHE = 16
class ScriptConverterErrors(Enum):
"""Enum definition for mindconverter errors."""
......@@ -22,380 +22,201 @@ class TestConverter:
converter_ins = Converter()
def test_judge_forward(self):
"""test judge_forward"""
name1 = 'conv1'
forward_list = {'conv1', 'relu'}
result1 = self.converter_ins.judge_forward(name1, forward_list)
assert result1 is True
name2 = 'self.forward'
result2 = self.converter_ins.judge_forward(name2, forward_list)
assert result2 is True
def test_find_left_parentheses(self):
"""test find_left_parentheses"""
code = '''nn.Sequential(nn.Conv2d(in_dim, 6, 5, stride=1, padding=0, ),
nn.ReLU(),
nn.ReLU(True),
nn.MaxPool2d(2, 2),
nn.Conv2d(6, 16, 5, stride=1, padding=0),
nn.ReLU(inplace=False),
nn.MaxPool2d(2, 2))'''
right_index = len(code) - 1
left_index = code.index('nn.Conv2d')
result = self.converter_ins.find_left_parentheses(code, right_index)
assert result == left_index - 1
def test_find_api(self):
"""test find_api"""
code = '''nn.Sequential(nn.Conv2d(in_dim, 6, 5, stride=1, padding=0, ),
nn.ReLU(),
nn.ReLU(True),
nn.MaxPool2d(2, 2), # TODO padding
nn.Conv2d(6, 16, 5, stride=1, padding=0),
nn.ReLU(inplace=False),
nn.MaxPool2d(2, 2))'''
index = 0
is_forward = False
result = self.converter_ins.find_api(code, index, is_forward)
assert result == 'nn.Sequential'
def test_get_call_name(self):
"""test get_call_name"""
code = '''nn.Sequential(nn.Conv2d(in_dim, 6, 5, stride=1, padding=0))'''
end = len(code)
call_name, index = self.converter_ins.get_call_name(code, end)
assert call_name == ''
assert index == -1
def test_find_right_parentheses(self):
"""test find_right_parentheses"""
code = '''nn.Sequential(nn.Conv2d(in_dim, 6, 5, stride=1, padding=0, ),
nn.ReLU(),
nn.ReLU(True),
nn.MaxPool2d(2, 2), # TODO padding
nn.Conv2d(6, 16, 5, stride=1, padding=0),
nn.ReLU(inplace=False),
nn.MaxPool2d(2, 2))'''
left_index = 0
result = self.converter_ins.find_right_parentheses(code, left_index)
assert_index = len(code) - 1
assert result == assert_index
# test convert_api with nn ops
def test_convert_api_nn_layernorm(self):
"""Test convert_api function work ok when convert api nn.LayerNorm"""
code = """
def __init__(self, num_classes=1000):
self.features = nn.SequentialCell([
nn.LayerNorm((5, 10, 10), elementwise_affine=False),
nn.ReLU(inplace=False)
])
"""
code = "nn.LayerNorm((5, 10, 10), elementwise_affine=False)"
api_name = 'nn.LayerNorm'
start = code.find(api_name)
layer_norm_info = NN_MAPPING.get(api_name)
expected_ms_api_name = 'nn.LayerNorm'
epsilon = layer_norm_info.pt_api.params.get('eps')
replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name)
replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('nn.LayerNorm((5, 10, 10), elementwise_affine=False)',
'{}(normalized_shape=(5, 10, 10), epsilon={})'.format(
expected_ms_api_name, epsilon))
assert new_start == start + len(expected_ms_api_name)
def test_convert_api_nn_leaky_relu(self):
"""Test convert_api function work ok when convert api nn.LeakyReLU"""
code = """
def __init__(self, num_classes=1000):
self.features = nn.SequentialCell([
nn.LayerNorm((5, 10, 10), elementwise_affine=False),
nn.LeakyReLU(0.3)])
"""
api_name = 'nn.LeakyReLU'
start = code.find(api_name)
code = "nn.LeakyReLU(0.3)"
expected_ms_api_name = 'nn.LeakyReLU'
replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name)
replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('nn.LeakyReLU(0.3)',
'{}(alpha=0.3)'.format(expected_ms_api_name))
assert new_start == start + len(expected_ms_api_name)
def test_convert_api_nn_prelu(self):
"""Test convert_api function work ok when convert api nn.PReLU"""
code = """
input = torch.randn(2, 3, 5)
nn.PReLU()(input)
"""
api_name = 'nn.PReLU'
start = code.find(api_name)
code = "nn.PReLU()(input)"
expected_ms_api_name = 'nn.PReLU'
replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name)
replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('nn.PReLU()(input)',
'{}()(input)'.format(expected_ms_api_name))
assert new_start == start + len(expected_ms_api_name)
def test_convert_api_nn_softmax(self):
"""Test convert_api function work ok when convert api nn.Softmax"""
code = """
nn.Softmax(dim=1)(input)
"""
api_name = 'nn.Softmax'
code = "nn.Softmax(dim=1)"
expected_ms_api_name = 'nn.Softmax'
start = code.find(api_name)
replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name)
assert replaced_code == code.replace('nn.Softmax(dim=1)(input)',
'{}(axis=1)(input)'.format(expected_ms_api_name))
assert new_start == start + len(expected_ms_api_name)
replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('nn.Softmax(dim=1)',
'{}(axis=1)'.format(expected_ms_api_name))
# test convert_api with torch dot ops
def test_convert_api_torch_dot_abs(self):
"""Test convert_api function work ok when convert api torch.abs"""
code = """
torch.abs(input)
"""
api_name = 'torch.abs'
start = code.find(api_name)
code = "torch.abs(input)"
expected_ms_api_name = 'P.Abs'
replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name)
replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('torch.abs(input)',
'{}()(input)'.format(expected_ms_api_name))
assert new_start == start + len(expected_ms_api_name)
def test_convert_api_torch_dot_acos(self):
"""Test convert_api function work ok when convert api torch.acos"""
code = """
torch.acos(input)
"""
api_name = 'torch.acos'
start = code.find(api_name)
code = "torch.acos(input)"
expected_ms_api_name = 'P.ACos'
replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name)
replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('torch.acos(input)',
'{}()(input)'.format(expected_ms_api_name))
assert new_start == start + len(expected_ms_api_name)
def test_convert_api_torch_dot_cos(self):
"""Test convert_api function work ok when convert api torch.cos"""
code = """
torch.cos(input)
"""
api_name = 'torch.cos'
code = "torch.cos(input)"
expected_ms_api_name = 'P.Cos'
start = code.find(api_name)
replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name)
replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('torch.cos(input)',
'{}()(input)'.format(expected_ms_api_name))
assert new_start == start + len(expected_ms_api_name)
def test_convert_api_torch_dot_exp(self):
"""Test convert_api function work ok when convert api torch.exp"""
code = """
torch.exp(input)
"""
api_name = 'torch.exp'
code = "torch.exp(input)"
expected_ms_api_name = 'P.Exp'
start = code.find(api_name)
replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name)
replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('torch.exp(input)',
'{}()(input)'.format(expected_ms_api_name))
assert new_start == start + len(expected_ms_api_name)
def test_convert_api_torch_dot_log(self):
"""Test convert_api function work ok when convert api torch.log"""
code = """
torch.log(input)
"""
api_name = 'torch.log'
code = "torch.log(input)"
expected_ms_api_name = 'P.Log'
start = code.find(api_name)
replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name)
replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('torch.log(input)',
'{}()(input)'.format(expected_ms_api_name))
assert new_start == start + len(expected_ms_api_name)
def test_convert_api_torch_dot_pow(self):
"""Test convert_api function work ok when convert api torch.pow"""
code = """
torch.pow(a, exp)
"""
api_name = 'torch.pow'
code = "torch.pow(a, exp)"
expected_ms_api_name = 'P.Pow'
start = code.find(api_name)
replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name)
replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('torch.pow(a, exp)',
'{}()(a, exp)'.format(expected_ms_api_name))
assert new_start == start + len(expected_ms_api_name)
def test_convert_api_torch_dot_div(self):
"""Test convert_api function work ok when convert api torch.div"""
code = """
input = torch.randn(5)
other = torch.randn(5)
torch.div(input, other)
"""
api_name = 'torch.div'
code = "torch.div(input, other)"
expected_ms_api_name = 'P.Div'
start = code.find(api_name)
replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name)
replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('torch.div(input, other)',
'{}()(input, other)'.format(expected_ms_api_name))
assert new_start == start + len(expected_ms_api_name)
def test_convert_api_torch_dot_sin(self):
"""Test convert_api function work ok when convert api torch.sin"""
code = """
torch.sin(input)
"""
api_name = 'torch.sin'
code = "torch.sin(input)"
expected_ms_api_name = 'P.Sin'
start = code.find(api_name)
replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name)
replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('torch.sin(input)',
'{}()(input)'.format(expected_ms_api_name))
assert new_start == start + len(expected_ms_api_name)
def test_convert_api_torch_dot_sqrt(self):
"""Test convert_api function work ok when convert api torch.sqrt"""
code = """
torch.sqrt(input)
"""
api_name = 'torch.sqrt'
code = "torch.sqrt(input)"
expected_ms_api_name = 'P.Sqrt'
start = code.find(api_name)
replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name)
replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('torch.sqrt(input)',
'{}()(input)'.format(expected_ms_api_name))
assert new_start == start + len(expected_ms_api_name)
def test_convert_api_torch_dot_eye_with_n(self):
"""Test convert_api function work ok when convert api torch.eye"""
code = """
torch.eye(3)
"""
api_name = 'torch.eye'
code = "torch.eye(3)"
expected_ms_api_name = 'P.Eye'
start = code.find(api_name)
replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name)
replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('torch.eye(3)',
'{}()(3, 3, mindspore.int32)'.format(expected_ms_api_name))
assert new_start == start + len(expected_ms_api_name)
def test_convert_api_torch_dot_eye_with_m(self):
"""Test convert_api function work ok when convert api torch.eye"""
code = """
torch.eye(3, 4)
"""
api_name = 'torch.eye'
code = "torch.eye(3, 4)"
expected_ms_api_name = 'P.Eye'
start = code.find(api_name)
replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name)
replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('torch.eye(3, 4)',
'{}()(3, 4, mindspore.int32)'.format(expected_ms_api_name))
assert new_start == start + len(expected_ms_api_name)
def test_convert_api_torch_dot_add_with_alpha_default(self):
"""Test convert_api function work ok when convert api torch.add"""
code = """
torch.add(input, value)
"""
api_name = 'torch.add'
code = "torch.add(input, value)"
expected_ms_api_name = 'P.TensorAdd'
start = code.find(api_name)
replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name)
replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('torch.add(input, value)',
'{}()(input, value)'.format(expected_ms_api_name))
assert new_start == start + len(expected_ms_api_name)
def test_convert_api_torch_dot_add_with_alpha_not_default(self):
"""Test convert_api function work ok when convert api torch.add"""
code = """
torch.add(input, value, 3)
"""
api_name = 'torch.add'
code = "torch.add(input, value, 3)"
expected_ms_api_name = 'P.TensorAdd'
start = code.find(api_name)
replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name)
replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('torch.add(input, value, 3)',
'{}()(input, value*3)'.format(expected_ms_api_name))
assert new_start == start + len(expected_ms_api_name)
# test convert_api with F ops
def test_convert_api_f_normalize(self):
"""Test convert_api function work ok when convert api F.normalize"""
code = """
input = torch.randn(2, 3, 5)
F.normalize(input)
"""
api_name = 'F.normalize'
start = code.find(api_name)
code = "F.normalize(input)"
expected_ms_api_name = 'P.L2Normalize'
replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name)
replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('F.normalize(input)',
'{}(1, 1e-12)(input)'.format(expected_ms_api_name))
assert new_start == start + len(expected_ms_api_name)
def test_convert_api_f_sigmoid(self):
"""Test convert_api function work ok when convert api F.sigmoid"""
code = """
input = torch.randn(2, 3, 5)
F.sigmoid(input)
"""
api_name = 'F.sigmoid'
start = code.find(api_name)
code = "F.sigmoid(input)"
expected_ms_api_name = 'P.Sigmoid'
replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name)
replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('F.sigmoid(input)',
'{}()(input)'.format(expected_ms_api_name))
assert new_start == start + len(expected_ms_api_name)
# test convert_api with tensor dot ops
def test_convert_api_tensor_dot_repeat(self):
"""Test convert_api function work ok when convert api .repeat"""
code = """
x.repeat(4, 2)
"""
api_name = '.repeat'
start = code.find(api_name)
code = "x.repeat(4, 2)"
expected_ms_api_name = 'P.Tile'
replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name)
replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('x.repeat(4, 2)',
'{}()(x, {})'.format(expected_ms_api_name, '(4, 2,)'))
assert new_start == start + len(expected_ms_api_name)
def test_convert_api_tensor_dot_permute(self):
"""Test convert_api function work ok when convert api .permute"""
code = """
x.permute(2, 0, 1)
"""
api_name = '.permute'
start = code.find(api_name)
code = "x.permute(2, 0, 1)"
expected_ms_api_name = 'P.Transpose'
replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name)
replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('x.permute(2, 0, 1)',
'{}()(x, (2, 0, 1,))'.format(expected_ms_api_name))
assert new_start == start + len(expected_ms_api_name)
......@@ -15,7 +15,6 @@
"""Test forward_call module."""
import ast
import textwrap
from unittest.mock import patch
from mindinsight.mindconverter.forward_call import ForwardCall
......@@ -50,12 +49,10 @@ class TestForwardCall:
return out
""")
@patch.object(ForwardCall, 'process')
def test_process(self, mock_process):
def test_process(self):
"""Test the function of visit ast tree to find out forward functions."""
mock_process.return_value = None
forward_call = ForwardCall("mock")
forward_call.visit(ast.parse(self.source))
ast_tree = ast.parse(self.source)
forward_call = ForwardCall(ast_tree)
expect_calls = ['TestNet.forward',
'TestNet.forward1',
......@@ -70,6 +67,6 @@ class TestForwardCall:
'TestNet.fc3',
]
expect_calls.sort()
real_calls = list(forward_call.calls)
real_calls = list(forward_call.calls.keys())
real_calls.sort()
assert real_calls == expect_calls
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册