提交 16e94c3f 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!135 torch2ms convert

Merge pull request !135 from quyongxiu1/br_qyx_0520
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Create a logger."""
from mindinsight.utils.log import setup_logger
logger = setup_logger("mindconverter", "mindconverter")
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless REQUIRED by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""API config"""
import ast
from collections import OrderedDict
from functools import partial
import json
import os
import pasta
from mindinsight.mindconverter.enums import RequriedType
from mindinsight.mindconverter.common.log import logger
REQUIRED = RequriedType.REQUIRED.name
UNREQUIRED = RequriedType.UNREQUIRED.name
class APIPt:
"""Base API for args parse, and API for one frame."""
def __init__(self, name: str, params: OrderedDict):
self.name = name
self.params = OrderedDict()
for k, value in params.items():
self.params[k] = self.to_str(value)
@staticmethod
def to_str(value):
"""
Trans value to str.
Args:
value (Union[str,Number,int]): Each value for params of OrderedDict.
Returns:
str, str type of value.
"""
if value is REQUIRED:
return value
if isinstance(value, str):
return "'{}'".format(value)
return str(value)
def parse_args(self, call_name: str, args_str: str):
"""
Parse call_name and args_str.
Args:
call_name (str): str of the call function, etc.
args_str (str): str of args for function, which starts with '(' and end with ')'.
Returns:
OrderedDict, all args parsed.
Raises:
ValueError: If can not use ast to parse or the required parse node not type of ast.Call,
or the given args_str not valid.
"""
# expr is REQUIRED to meet (**) format
if not (len(args_str) >= 2 and args_str[0] == "(" and args_str[-1] == ")"):
raise ValueError('[{}] is think as args str, it should start with "(" and end with ")"'.format(args_str))
try:
ast_node = ast.parse("whatever_call_name" + args_str)
call_node = ast_node.body[0].value
if not isinstance(call_node, ast.Call):
raise ValueError('call name with args str [{}] not instance of ast.Call'.format(args_str))
except:
raise ValueError("can't parse code:\n{}".format(args_str))
# regard all actual parameter as one parameter
if len(self.params) == 1:
k = list(self.params.keys())[0]
if k.startswith('*'):
value = args_str[1:-1]
return OrderedDict([(k, value), ("call_name", call_name)])
args = OrderedDict()
# param which name not assigned
param_iter = iter(self.params.keys())
if len(call_node.args) > len(self.params):
raise ValueError('Parse args of torch in {}, but there is problems with params'.format(call_name))
for arg in call_node.args:
if isinstance(arg, ast.Starred):
logger.debug("Find *%s", arg.value.id)
args['*'] = arg.value.id
else:
# remove \n
args[next(param_iter)] = pasta.dump(arg).strip()
# params which name is assigned
for keyword in call_node.keywords:
if keyword.arg is None:
logger.info("Find **%s", keyword.value.id)
args['**'] = keyword.value.id
else:
# remove \n
args[keyword.arg] = pasta.dump(keyword.value).strip()
args["call_name"] = call_name
return args
class APIMs(APIPt):
"""API for MindSpore"""
def __init__(self, name: str, params: OrderedDict, p_attrs=None):
self.is_primitive = name.startswith('P.')
if self.is_primitive:
self.p_attrs = p_attrs if p_attrs else set()
super(APIMs, self).__init__(name, params)
def create_args(self, params_pt, args_pt, ms2pt_map, explicit_map):
"""
Create args for MindSpore according to other frame op info.
Args:
params_pt (OrderedDict): Params used for initialize function of APIPt.
args_pt (OrderedDict): Args parsed from APIPt.
ms2pt_map (dict): Dict of params mapping relation for ops between frames.
explicit_map(func): Function to generate mapping relation for ops between frames.
Returns:
OrderedDict, args for MindSpore.
"""
args = OrderedDict()
# traverse MindSpore's params
for k in self.params.keys():
# has relevant param? yes
if k in ms2pt_map:
if ms2pt_map[k] in args_pt:
# user assigned value
args[k] = args_pt[ms2pt_map[k]]
elif self.params[k] != params_pt[ms2pt_map[k]]:
# user didn't assigned value, but initial value different between 2 frames
args[k] = params_pt[ms2pt_map[k]]
# has relevant param? no
else:
# params forced to display
if k in explicit_map:
args[k] = explicit_map[k]
elif self.params[k] is REQUIRED:
args[k] = "<REQUIRED>"
# find * or ** in frame actual parameters
for star in ('*', '**'):
if star in args_pt:
args[star] = args_pt[star]
return args
class MappingHelper:
"""Mapping from one frame to another frame"""
def __init__(self, ms_api: APIMs, pt_api: APIPt, **kwargs):
ms2pt_mapping = kwargs.get('ms2pt_mapping')
gen_explicit_map = kwargs.get('gen_explicit_map')
export_key = kwargs.get('export_key')
if ms2pt_mapping is None:
ms2pt_mapping = {}
if gen_explicit_map is None:
gen_explicit_map = lambda params_pt, args_pt: {}
self.ms_api = ms_api
self.pt_api = pt_api
self.ms2pt_mapping = ms2pt_mapping
self.gen_explicit_map = gen_explicit_map
if export_key is not None:
self.export_key = export_key
else:
self.export_key = not ms_api.is_primitive
def gen_args_expr(self, args):
"""
Generate str assignment statement from given dict.
Args:
args (OrderedDict): Key, value pairs for assignment source.
Returns:
str, generated str.
"""
expr = ''
for key, value in args.items():
if expr:
expr += ', '
sym = '' if key in ('*', '**') else '='
if self.export_key:
expr += key + sym
expr += value
return expr
def gen_args_expr_for_p(self, args, p_attrs):
"""
Generate str assignment statement from given dict for primitive and not primitive.
Args:
args (OrderedDict): Key, value pairs for assignment source.
p_attrs (set): Exclusive params for operator.
Returns:
tuple, generated str for primitive, generated str for not primitive.
"""
args_attrs = OrderedDict([(k, v) for k, v in args.items() if k in p_attrs])
args_ios = OrderedDict([(k, v) for k, v in args.items() if k not in p_attrs])
return self.gen_args_expr(args_attrs), self.gen_args_expr(args_ios)
def convert(self, call_name_pt: str, args_str_pt: str):
"""
Convert code sentence to MindSpore code sentence.
Args:
call_name_pt (str): str of the call function, etc.
args_str_pt (str): str of args for function, which starts with '(' and end with ')'.
Returns:
str, converted code sentence for MindSpore.
"""
# all value for args_pt is str
args_pt = self.pt_api.parse_args(call_name_pt, args_str_pt)
# all value for args_ms is str
explicit_map = self.gen_explicit_map(self.pt_api.params, args_pt)
args_ms = self.ms_api.create_args(self.pt_api.params, args_pt, self.ms2pt_mapping, explicit_map)
if self.ms_api.is_primitive:
if self.pt_api.name == '.size' and 'idx' in args_pt:
args_expr = self.gen_args_expr(args_ms)
expr_ms = "%s()(%s)[%s]" % (self.ms_api.name, args_expr, args_pt['idx'])
else:
expr_attrs, expr_ios = self.gen_args_expr_for_p(args_ms, self.ms_api.p_attrs)
expr_ms = "%s(%s)(%s)" % (self.ms_api.name, expr_attrs, expr_ios)
else:
ms_expr = self.gen_args_expr(args_ms)
expr_ms = "%s(%s)" % (self.ms_api.name, ms_expr)
return expr_ms
def gen_explicit_map_nn_sequential(_, args_pt):
"""
Generate explicit_map for nn.Sequential.
Args:
args_pt (dict): Args for APIPt.
Returns:
dict, map between frames.
"""
args = args_pt['*args']
return {"*args": "[{}]".format(args)}
def gen_explicit_map_nn_maxpool2d(params_pt, args_pt):
"""
Generate explicit_map for nn.MaxPool2d.
Args:
params_pt (dict): Params for APIPt.
args_pt (dict): Args for APIPt.
Returns:
dict, map between frames.
"""
if 'padding' in args_pt:
padding = args_pt['padding']
else:
padding = params_pt['padding']
if padding.strip() in ("0", "(0,0)", "(0, 0)"):
pad_mode = "'valid'"
else:
pad_mode = "'same'"
return {"pad_mode": pad_mode}
def gen_explicit_map_f_max_pool2d(params_pt, args_pt):
"""
Generate explicit_map for F.MaxPool2d.
Args:
params_pt (dict): Params for APIPt.
args_pt (dict): Args for APIPt.
Returns:
dict, map between frames.
"""
if 'padding' in args_pt:
padding = args_pt['padding']
else:
padding = params_pt['padding']
if padding.strip() in ("0", "(0,0)", "(0, 0)"):
padding = "'valid'"
else:
padding = "'same'"
return {"padding": padding}
def gen_explicit_map_one_delta(params_pt, args_pt, k_ms, k_pt):
"""
Generate explicit_map for which include mapping relationship is `1 - k_ms = k_pt`.
Args:
params_pt (dict): Params for APIPt.
args_pt (dict): Args for APIPt.
Returns:
dict, map between frames.
"""
value = args_pt[k_pt] if k_pt in args_pt else params_pt[k_pt]
value = value.strip()
def is_number(string):
try:
float(string)
return True
except ValueError:
return False
if is_number(value):
return {k_ms: str(1 - float(value))}
return {k_ms: "1.0 - " + value}
def load_json_file(file_path):
"""
Load data from given json file path.
Args:
file_path (str): The file to load json data from.
Returns:
list, the list data stored in file_path.
"""
with open(file_path, 'r', encoding='utf-8') as file:
info = json.loads(file.read())
return info
# ---------------------------- mappings ----------------------------
NN_MAPPING = {
'nn.Sequential': MappingHelper(**{"ms_api": APIMs('nn.SequentialCell', OrderedDict([('*args', REQUIRED)])),
"pt_api": APIPt('nn.Sequential', OrderedDict([('*args', REQUIRED)])),
"gen_explicit_map": gen_explicit_map_nn_sequential,
"export_key": False
}),
'nn.Conv2d': MappingHelper(**{"ms_api": APIMs('nn.Conv2d', OrderedDict(in_channels=REQUIRED,
out_channels=REQUIRED,
kernel_size=REQUIRED,
stride=1,
pad_mode='same',
padding=0,
dilation=1,
group=1,
has_bias=False,
weight_init='normal',
bias_init='zeros')),
"pt_api": APIPt('nn.Conv2d', OrderedDict(in_channels=REQUIRED,
out_channels=REQUIRED,
kernel_size=REQUIRED,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode='zeros')),
"ms2pt_mapping": {'in_channels': 'in_channels',
'out_channels': 'out_channels',
'kernel_size': 'kernel_size',
'stride': 'stride',
'padding': 'padding',
'dilation': 'dilation',
'group': 'groups',
'has_bias': 'bias'},
"gen_explicit_map": (lambda params_pt, args_pt: {"pad_mode": "'pad'"})
}),
'nn.BatchNorm2d': MappingHelper(**{"ms_api": APIMs('nn.BatchNorm2d', OrderedDict(num_features=REQUIRED,
eps=1e-5,
momentum=0.9,
affine=True,
gamma_init='ones',
beta_init='zeros',
moving_mean_init='zeros',
moving_var_init='ones',
use_batch_statistics=True)),
"pt_api": APIPt('nn.BatchNorm2d', OrderedDict(num_features=REQUIRED,
eps=1e-5,
momentum=0.1,
affine=True,
track_running_stats=True)),
"ms2pt_mapping": {"num_features": "num_features",
"eps": "eps",
"affine": "affine",
"use_batch_statistics": "track_running_stats"},
"gen_explicit_map": partial(gen_explicit_map_one_delta,
k_ms="momentum", k_pt="momentum")
}),
'nn.ReLU': MappingHelper(**{"ms_api": APIMs('nn.ReLU', OrderedDict()),
"pt_api": APIPt('nn.ReLU', OrderedDict(inplace=False)),
"ms2pt_mapping": {}}),
'nn.ReLU6': MappingHelper(**{"ms_api": APIMs('nn.ReLU6', OrderedDict()),
"pt_api": APIPt('nn.ReLU6', OrderedDict(inplace=False)),
"ms2pt_mapping": {}}),
'nn.Linear': MappingHelper(**{"ms_api": APIMs('nn.Dense', OrderedDict(in_channels=REQUIRED,
out_channels=REQUIRED,
weight_init='normal',
bias_init='zeros',
has_bias=True,
activation=None)),
"pt_api": APIPt('nn.Linear', OrderedDict(in_features=REQUIRED,
out_features=REQUIRED,
bias=True)),
"ms2pt_mapping": {"in_channels": "in_features",
"out_channels": "out_features",
"has_bias": "bias"}
}),
'nn.MaxPool2d': MappingHelper(**{"ms_api": APIMs('nn.MaxPool2d', OrderedDict(kernel_size=1,
stride=1,
pad_mode="valid")),
"pt_api": APIPt('nn.MaxPool2d', OrderedDict(kernel_size=REQUIRED,
stride=None,
padding=0,
dilation=1,
return_indices=False,
ceil_mode="False")),
"ms2pt_mapping": {"kernel_size": "kernel_size",
"stride": "stride"},
"gen_explicit_map": gen_explicit_map_nn_maxpool2d
}),
'nn.AvgPool2d': MappingHelper(**{"ms_api": APIMs('nn.AvgPool2d', OrderedDict(kernel_size=1,
stride=1,
pad_mode="valid")),
"pt_api": APIPt('nn.AvgPool2d', OrderedDict(kernel_size=REQUIRED,
stride=None,
padding=0,
dilation=1,
return_indices=False,
ceil_mode="False")),
"ms2pt_mapping": {"kernel_size": "kernel_size",
"stride": "stride"},
"gen_explicit_map": gen_explicit_map_nn_maxpool2d
}),
'nn.Dropout': MappingHelper(**{"ms_api": APIMs('nn.Dropout', OrderedDict(keep_prob=0.5,
seed0=0,
seed1=0,
dtype="mstype.float32")),
"pt_api": APIPt('nn.Dropout', OrderedDict(p=0.5,
inplace=False)),
"ms2pt_mapping": {"keep_prob": "p"},
"gen_explicit_map": partial(gen_explicit_map_one_delta,
k_ms="keep_prob", k_pt="p")
})
}
# set alias nn. = torch.nn.
NN_MAPPING.update({"torch." + k: v for k, v in NN_MAPPING.items()})
F_MAPPING = {
'F.relu': MappingHelper(**{"ms_api": APIMs('P.ReLU', OrderedDict(input=REQUIRED)),
"pt_api": APIPt('F.relu', OrderedDict(input=REQUIRED, inplace=False)),
"ms2pt_mapping": {"input": "input"},
}),
'F.relu6': MappingHelper(**{"ms_api": APIMs('P.ReLU6', OrderedDict(input=REQUIRED)),
"pt_api": APIPt('F.relu6', OrderedDict(input=REQUIRED, inplace=False)),
"ms2pt_mapping": {"input": "input"},
}),
'F.max_pool2d': MappingHelper(**{"ms_api": APIMs('P.MaxPool', OrderedDict(ksize=1,
strides=1,
padding="valid",
input=REQUIRED),
p_attrs={"ksize", "strides", "padding"}),
"pt_api": APIPt('F.max_pool2d', OrderedDict(input=REQUIRED,
kernel_size=REQUIRED,
stride=None,
padding=0,
dilation=1,
ceil_mode=False,
return_indices=False)),
"ms2pt_mapping": {"ksize": "kernel_size",
"strides": "stride",
"input": "input",
},
"gen_explicit_map": gen_explicit_map_f_max_pool2d
}),
'F.avg_pool2d': MappingHelper(**{"ms_api": APIMs('P.AvgPool', OrderedDict(ksize=1,
strides=1,
padding="valid",
input=REQUIRED),
p_attrs={"ksize", "strides", "padding"}),
"pt_api": APIPt('F.avg_pool2d', OrderedDict(input=REQUIRED,
kernel_size=REQUIRED,
stride=None,
padding=0,
dilation=1,
ceil_mode=False,
return_indices=False)),
"ms2pt_mapping": {"ksize": "kernel_size",
"strides": "stride",
"input": "input",
},
"gen_explicit_map": gen_explicit_map_f_max_pool2d
}),
}
# set alias F = nn.functional = torch.nn.functional
nn_functional_d = {"nn.functional." + k[2:]: v for k, v in F_MAPPING.items()}
torch_nn_functional_d = {"torch.nn.functional." + k[2:]: v for k, v in F_MAPPING.items()}
F_MAPPING.update(nn_functional_d)
F_MAPPING.update(torch_nn_functional_d)
TORCH_DOT_MAPPING = {
'torch.flatten': MappingHelper(**{"ms_api": APIMs('P.Flatten', OrderedDict(input=REQUIRED)),
"pt_api": APIPt('torch.flatten', OrderedDict(input=REQUIRED,
start_dim=0,
end_dim=-1)),
"ms2pt_mapping": {"input": "input"}
}),
'torch.cat': MappingHelper(**{"ms_api": APIMs('P.Concat',
OrderedDict(axis=0, input=REQUIRED),
p_attrs={"axis"}),
"pt_api": APIPt('torch.flatten', OrderedDict(tensors=REQUIRED, dim=0, out=None)),
"ms2pt_mapping": {"input": "tensors",
"axis": "dim"}
}),
}
TENSOR_DOT_MAPPING = {
'.view': MappingHelper(**{"ms_api": APIMs('P.Reshape', OrderedDict(x=REQUIRED, shape=REQUIRED)),
"pt_api": APIPt('.view', OrderedDict([('*shape', REQUIRED)])),
"ms2pt_mapping": {"x": "call_name"},
"gen_explicit_map": (lambda params_pt, args_pt: {"shape": "(" + args_pt["*shape"] + ",)"})
}),
'.size': MappingHelper(**{"ms_api": APIMs('P.Shape', OrderedDict(x=REQUIRED)),
"pt_api": APIPt('.size', OrderedDict([('idx', REQUIRED)])),
"ms2pt_mapping": {"x": "call_name"}
}),
'.flatten': MappingHelper(**{"ms_api": APIMs('P.Flatten', OrderedDict(input=REQUIRED)),
"pt_api": APIPt('.flatten', OrderedDict(start_dim=0,
end_dim=-1)),
"ms2pt_mapping": {"input": "call_name"}
}),
'.reshape': MappingHelper(**{"ms_api": APIMs('P.Reshape', OrderedDict(x=REQUIRED, shape=REQUIRED)),
"pt_api": APIPt('.reshape', OrderedDict([('*shape', REQUIRED)])),
"ms2pt_mapping": {"x": "call_name"},
"gen_explicit_map": (
lambda params_pt, args_pt: {"shape": "(" + args_pt["*shape"] + ",)"})
}),
'.mean': MappingHelper(**{"ms_api": APIMs('P.ReduceMean', OrderedDict(keep_dims=False,
input=REQUIRED,
axis=()),
p_attrs={"keep_dims"}),
"pt_api": APIPt('.mean', OrderedDict(dim=None,
keepdim=False)),
"ms2pt_mapping": {"keep_dims": "keepdim",
"axis": "dim",
"input": "call_name"},
}),
'.squeeze': MappingHelper(**{"ms_api": APIMs('P.ReduceMean', OrderedDict(input=REQUIRED,
axis=()),
p_attrs={"axis"}),
"pt_api": APIPt('.squeeze', OrderedDict(dim=None)),
"ms2pt_mapping": {"axis": "dim",
"input": "call_name"},
}),
}
ALL_MAPPING = {**NN_MAPPING, **F_MAPPING, **TORCH_DOT_MAPPING, **TENSOR_DOT_MAPPING}
# ---------------------------- api list support or not support ----------------------------
NN_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'nn_list.json'))
NN_LIST = load_json_file(NN_LIST_PATH)
# set alias nn. = torch.nn.
NN_LIST += ["torch." + name for name in NN_LIST]
NN_SUPPORTED = [x for x in NN_LIST if x in ALL_MAPPING]
NN_UNSUPPORTED = [x for x in NN_LIST if x not in ALL_MAPPING]
F_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'f_list.json'))
F_LIST = load_json_file(F_LIST_PATH)
# set alias F = nn.functional = torch.nn.functional
F_LIST += ["F." + name[len("torch.nn.functional."):] for name in F_LIST] + \
[name[len("torch."):] for name in F_LIST]
F_SUPPORTED = [x for x in F_LIST if x in ALL_MAPPING]
F_UNSUPPORTED = [x for x in F_LIST if x not in ALL_MAPPING]
TORCH_DOT_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'torch_dot_list.json'))
TORCH_DOT_LIST = load_json_file(TORCH_DOT_LIST_PATH)
TORCH_DOT_SUPPORTED = [x for x in TORCH_DOT_LIST if x in ALL_MAPPING]
TORCH_DOT_UNSUPPORTED = [x for x in TORCH_DOT_LIST if x not in ALL_MAPPING]
TENSOR_DOT_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'tensor_dot_list.json'))
TENSOR_DOT_LIST = load_json_file(TENSOR_DOT_LIST_PATH)
TENSOR_DOT_SUPPORTED = [x for x in TENSOR_DOT_LIST if x in ALL_MAPPING]
TENSOR_DOT_UNSUPPORTED = [x for x in TENSOR_DOT_LIST if x not in ALL_MAPPING]
ALL_2P_LIST = F_LIST + TORCH_DOT_LIST + TENSOR_DOT_LIST
ALL_TORCH_APIS = NN_LIST + F_LIST + TORCH_DOT_LIST + TENSOR_DOT_LIST
ALL_SUPPORTED = NN_SUPPORTED + F_SUPPORTED + TORCH_DOT_SUPPORTED + TENSOR_DOT_SUPPORTED
ALL_UNSUPPORTED = NN_UNSUPPORTED + F_UNSUPPORTED + TORCH_DOT_UNSUPPORTED + TENSOR_DOT_UNSUPPORTED
UNSUPPORTED_WARN_INFOS = {
"nn.AdaptiveAvgPool2d": "maybe could convert to P.ReduceMean",
"F.adaptive_avg_pool2d": "maybe could convert to P.ReduceMean",
"F.dropout": "please use nn.Dropout in __init__()",
}
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless REQUIRED by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""main module"""
import inspect
import copy
import importlib
import os
import stat
from mindinsight.mindconverter.config import ALL_MAPPING
from mindinsight.mindconverter.config import NN_LIST
from mindinsight.mindconverter.config import ALL_TORCH_APIS
from mindinsight.mindconverter.config import ALL_2P_LIST
from mindinsight.mindconverter.config import UNSUPPORTED_WARN_INFOS
from mindinsight.mindconverter.common.log import logger
def is_local_defined(obj, member):
"""
Check if obj and member are both defined in the same source file.
Args:
obj (Union[object, module]): A module or a class.
member (func): A function of obj.
Returns:
bool, True or False.
"""
srcfile = inspect.getsourcefile(obj)
return inspect.getsourcefile(member) == srcfile
def is_valid_module(obj, member):
"""
Check if obj and member defined in same source file and member is inherited from torch.nn.Module.
Args:
obj (Union[object, module]): A module or a class.
member (func): A function.
Returns:
bool, True or False.
"""
return inspect.isclass(member) and (member.__base__.__name__ == 'Module') and is_local_defined(obj, member)
def is_valid_function(obj, member):
"""
Check if member is function and defined in the file same as obj.
Args:
obj (Union[object, module]: The obj.
member (func): The func.
Returns:
bool, True or False.
"""
return inspect.isfunction(member) and is_local_defined(obj, member)
def find_left_parentheses(string, right):
"""
Find index of the first left parenthesis.
Args:
string (str): A line of code.
right (int): Max index of string, same as `len(string) -1`.
Returns:
int, index of the first parenthesis.
Raises:
ValueError: If line of code doesn't contain any pair of `()` or `(` and `)` are not paired.
"""
if string[right] != ')':
raise ValueError('code [{}] at index {} not ")".'.format(string, right))
stack = []
for i in range(right, -1, -1):
if string[i] == ')':
stack.append(')')
elif string[i] == '(':
stack.pop()
if not stack:
return i
raise ValueError("{} should contain ()".format(string))
def find_right_parentheses(string, left):
"""
Find first index of right parenthesis which make all left parenthesis make sense.
Args:
string (str): A line of code.
left (int): Start index of string to find from.
Returns:
int, index of the found right parenthesis.
Raises:
ValueError: If line of code doesn't contain any pair of `()` or `(` and `)` are not paired.
"""
stack = []
for i in range(left, len(string)):
if string[i] == '(':
stack.append('(')
elif string[i] == ')':
stack.pop()
if not stack:
return i
raise ValueError("{} should contain ()".format(string))
def get_call_name(code, end):
"""
Traverse code in a reversed function from index end and get the call name and start index of the call name, if call
name not found, return a null character string and -1
Args:
code (str): The str of code to find from.
end (int): Start index to find.
Returns:
str, founded api name if found, else a null character string.
int, start index of founded api name, -1 if api name not found
"""
stack = []
for i in range(end - 1, -1, -1):
if code[i] in ["(", "[", "{"]:
if stack:
stack.pop()
else:
return code[i + 1:end], i + 1
elif code[i] in [")", "]", "}"]:
stack.append(code[i])
elif stack:
continue
elif not (code[i].isalpha() or code[i].isdigit() or code[i] == '_' or code[i] == '.'):
return code[i + 1:end], i + 1
return "", -1
def convert_api(code, start, api_name=""):
"""
Convert api_name in code to MindSpore api with start as a start index, if api_name is a python api, code will not
convert.
Args:
code (str): The str code to convert.
start (int): The index of code to start convert from.
api_name (str): The api name to convert.
Returns:
str, the converted code.
int, index of converted api_name in code.
"""
# handle format like .shape(
if api_name.startswith('.'):
call_name, new_start = get_call_name(code, start)
if start == -1 or call_name == "self":
return code, start + 1
else:
call_name = api_name
new_start = start
# find full api expected to be converted. eg:expr="nn.Conv2d(1,2,3)" args_str="(1,2,3)"
left = code.find("(", start)
if left == -1:
raise ValueError('"(" not found, {} should work with "("'.format(call_name))
right = find_right_parentheses(code, left)
end = right
expr = code[start:end + 1]
args_str = code[left:right + 1]
map_helper = ALL_MAPPING[api_name]
new_expr = map_helper.convert(call_name, args_str)
next_newline = code.find("\n", end + 1)
fill_num = (expr.count("\n") - new_expr.count("\n"))
if next_newline != -1:
code = code[:new_start] + new_expr + code[end + 1:next_newline] + ("\n" * fill_num) + code[next_newline:]
else:
code = code[:new_start] + new_expr + ")" + ("\n" * fill_num) + code[end + 2:]
return code, start + len(map_helper.ms_api.name)
def find_api(code, i, is_forward):
"""
Find api name from code with a start index i, check api name ok with a is_forward condition.
Args:
code (str): The code from which to find api name.
i (int): The start index to find.
is_forward (bool): Check if the found api name ok.
Returns:
str, api name if find api name and check ok with is_forward condition, else a null character string.
"""
if code[i:].startswith("nn.") \
or code[i:].startswith("F.") \
or code[i:].startswith("torch.") \
or code[i:].startswith('.'):
j = code.find('(', i)
if j != -1 and code[i:j] in ALL_TORCH_APIS:
api_name = code[i:j]
if (not is_forward and api_name in NN_LIST) or (is_forward and api_name in ALL_2P_LIST):
return api_name
return ""
def convert_function(fun_name, fun, is_forward):
"""
Convert a PyTorch function into MindSpore function.
Args:
fun_name (str): The str of function name.
fun (func): The function to convert.
is_forward (bool): If the function is defined in forward function in nn.Module in torch.
Returns:
dict, old code and converted code map if convert happens, else {}.
"""
_, line_no = inspect.getsourcelines(fun)
logger.info("Line %3d: start converting function %s()", line_no, fun_name)
code = inspect.getsource(fun)
code_saved = copy.copy(code)
i = 0
while i < len(code):
api_name = find_api(code, i, is_forward)
if api_name:
line_no1 = line_no + code[:i].count('\n')
if api_name in ALL_MAPPING:
logger.info("Line %3d start converting API: %s", line_no1, api_name)
code, i = convert_api(code, i, api_name)
continue
warn_info = ". " + UNSUPPORTED_WARN_INFOS[api_name] if api_name in UNSUPPORTED_WARN_INFOS else ""
logger.warning("Line %3d: found unsupported API: %s%s", line_no1, api_name, warn_info)
i += 1
return {code_saved: code} if code_saved != code else {}
def judge_forward(name, forward_list):
"""
Check if function is a forward function.
Args:
name (str): The function name.
forward_list (set): A set of forward function.
Returns:
bool, True or False
"""
is_forward = name in forward_list or name.split(".")[-1] == "forward"
if is_forward:
logger.debug("%s is a forward function", name)
return is_forward
def convert_module(module_name, module, forward_list):
"""
Convert a PyTorch module code into MindSpore module code.
Args:
module_name (str): The module's name.
module (module): The module to convert.
forward_list (set): A set of forward function.
Returns:
dict, map of old code and converted code.
"""
_, line_no = inspect.getsourcelines(module)
logger.info("Line {:3d}: start converting nn.Module {}".format(line_no, module_name))
mapped = {}
for name, member in inspect.getmembers(module):
if is_valid_function(module, member):
is_forward = judge_forward("{}.{}".format(module_name, name), forward_list)
mapped.update(convert_function(name, member, is_forward))
return mapped
def get_mapping(import_mod, forward_list):
"""
Convert code of a module and get mapping of old code and convert code.
Args:
import_mod (module): The module to convert.
forward_list (set): A set of forward function.
Returns:
dict, mapping for old code and converted code of the module
"""
mapping = {}
tasks = []
for name, member in inspect.getmembers(import_mod):
if is_valid_module(import_mod, member):
_, line_no = inspect.getsourcelines(member)
tasks.append((line_no, convert_module, (name, member, forward_list)))
elif is_valid_function(import_mod, member):
_, line_no = inspect.getsourcelines(member)
is_forward = judge_forward("{}.{}".format(import_mod, name), forward_list)
tasks.append((line_no, convert_function, (name, member, is_forward)))
tasks.sort()
for _, convert_fun, args in tasks:
mapping.update(convert_fun(*args))
return mapping
def convert(import_name, nn_module):
"""
The entrance for convert a module's code, code converted will be write to file called out.py.
Args:
import_name (str): The module from which to import the module to convert.
nn_module (str): Name of the module to convert.
"""
logger.info("Start converting %s.%s", import_name, nn_module)
import_mod = importlib.import_module(import_name)
forward_list = set()
logger.debug("Forward_list: %s", forward_list)
# replace python function under nn.Modlue
mapping = get_mapping(import_mod, forward_list)
code = inspect.getsource(import_mod)
for key, value in mapping.items():
code = code.replace(key, value)
code = 'import mindspore.ops.operations as P\n' + code
code = 'import mindspore.nn as nn\n' + code
code = 'import mindspore\n' + code
code = code.replace('import torch', '# import torch')
code = code.replace('from torch', '# from torch')
code = code.replace('(nn.Module):', '(nn.Cell):')
code = code.replace('forward(', 'construct(')
code = code.replace('nn.Linear', 'nn.Dense')
code = code.replace('(nn.Sequential)', '(nn.SequentialCell)')
code = code.replace('nn.init.', 'pass # nn.init.')
flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL
modes = stat.S_IWUSR | stat.S_IRUSR
with os.fdopen(os.open('out.py', flags, modes), 'w') as file:
file.write(code)
logger.info("Convert success. Result is wrote to out.py\n")
if __name__ == '__main__':
convert('torchvision.models.resnet', 'resnet18')
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Enums."""
from enum import Enum
class RequriedType(Enum):
"""If param is required"""
REQUIRED = 1
UNREQUIRED = 2
[
"torch.nn.functional.conv1d",
"torch.nn.functional.conv2d",
"torch.nn.functional.conv3d",
"torch.nn.functional.conv_transpose1d",
"torch.nn.functional.conv_transpose2d",
"torch.nn.functional.conv_transpose3d",
"torch.nn.functional.unfold",
"torch.nn.functional.fold",
"torch.nn.functional.avg_pool1d",
"torch.nn.functional.avg_pool2d",
"torch.nn.functional.avg_pool3d",
"torch.nn.functional.max_pool1d",
"torch.nn.functional.max_pool2d",
"torch.nn.functional.max_pool3d",
"torch.nn.functional.max_unpool1d",
"torch.nn.functional.max_unpool2d",
"torch.nn.functional.max_unpool3d",
"torch.nn.functional.lp_pool1d",
"torch.nn.functional.lp_pool2d",
"torch.nn.functional.adaptive_max_pool1d",
"torch.nn.functional.adaptive_max_pool2d",
"torch.nn.functional.adaptive_max_pool3d",
"torch.nn.functional.adaptive_avg_pool1d",
"torch.nn.functional.adaptive_avg_pool2d",
"torch.nn.functional.adaptive_avg_pool3d",
"torch.nn.functional.threshold",
"torch.nn.functional.threshold_",
"torch.nn.functional.relu",
"torch.nn.functional.relu_",
"torch.nn.functional.hardtanh",
"torch.nn.functional.hardtanh_",
"torch.nn.functional.relu6",
"torch.nn.functional.elu",
"torch.nn.functional.elu_",
"torch.nn.functional.selu",
"torch.nn.functional.celu",
"torch.nn.functional.leaky_relu",
"torch.nn.functional.leaky_relu_",
"torch.nn.functional.prelu",
"torch.nn.functional.rrelu",
"torch.nn.functional.rrelu_",
"torch.nn.functional.glu",
"torch.nn.functional.gelu",
"torch.nn.functional.logsigmoid",
"torch.nn.functional.hardshrink",
"torch.nn.functional.tanhshrink",
"torch.nn.functional.softsign",
"torch.nn.functional.softplus",
"torch.nn.functional.softmin",
"torch.nn.functional.softmax",
"torch.nn.functional.softshrink",
"torch.nn.functional.gumbel_softmax",
"torch.nn.functional.log_softmax",
"torch.nn.functional.tanh",
"torch.nn.functional.sigmoid",
"torch.nn.functional.batch_norm",
"torch.nn.functional.instance_norm",
"torch.nn.functional.layer_norm",
"torch.nn.functional.local_response_norm",
"torch.nn.functional.normalize",
"torch.nn.functional.linear",
"torch.nn.functional.bilinear",
"torch.nn.functional.dropout",
"torch.nn.functional.alpha_dropout",
"torch.nn.functional.dropout2d",
"torch.nn.functional.dropout3d",
"torch.nn.functional.embedding",
"torch.nn.functional.embedding_bag",
"torch.nn.functional.one_hot",
"torch.nn.functional.pairwise_distance",
"torch.nn.functional.cosine_similarity",
"torch.nn.functional.pdist",
"torch.nn.functional.binary_cross_entropy",
"torch.nn.functional.binary_cross_entropy_with_logits",
"torch.nn.functional.poisson_nll_loss",
"torch.nn.functional.cosine_embedding_loss",
"torch.nn.functional.cross_entropy",
"torch.nn.functional.ctc_loss",
"torch.nn.functional.log_softmax",
"torch.nn.functional.hinge_embedding_loss",
"torch.nn.functional.kl_div",
"torch.nn.functional.l1_loss",
"torch.nn.functional.mse_loss",
"torch.nn.functional.margin_ranking_loss",
"torch.nn.functional.multilabel_margin_loss",
"torch.nn.functional.multilabel_soft_margin_loss",
"torch.nn.functional.multi_margin_loss",
"torch.nn.functional.nll_loss",
"torch.nn.functional.smooth_l1_loss",
"torch.nn.functional.soft_margin_loss",
"torch.nn.functional.triplet_margin_loss",
"torch.nn.functional.pixel_shuffle",
"torch.nn.functional.pixel_shuffle",
"torch.nn.functional.pad",
"torch.nn.functional.interpolate",
"torch.nn.functional.upsample",
"torch.nn.functional.interpolate",
"torch.nn.functional.upsample_nearest",
"torch.nn.functional.interpolate",
"torch.nn.functional.upsample_bilinear",
"torch.nn.functional.interpolate",
"torch.nn.functional.grid_sample",
"torch.nn.functional.affine_grid"
]
\ No newline at end of file
[
"nn.Module",
"nn.CELU",
"nn.ELU",
"nn.GLU",
"nn.GELU",
"nn.Hardshrink",
"nn.Hardtanh",
"nn.LeakyReLU",
"nn.LogSigmoid",
"nn.LogSoftmax",
"nn.PReLU",
"nn.RReLU",
"nn.ReLU",
"nn.ReLU6",
"nn.SELU",
"nn.Sigmoid",
"nn.Softmax",
"nn.Softmax2d",
"nn.Softmin",
"nn.Softplus",
"nn.Softshrink",
"nn.Softsign",
"nn.Tanh",
"nn.Tanhshrink",
"nn.Threshold",
"nn.MultiheadAttention",
"nn.AdaptiveLogSoftmaxWithLoss",
"nn.BatchNorm1d",
"nn.BatchNorm2d",
"nn.BatchNorm3d",
"nn.SyncBatchNorm",
"nn.Container",
"nn.ModuleDict",
"nn.ModuleList",
"nn.ParameterDict",
"nn.ParameterList",
"nn.Sequential",
"nn.Conv1d",
"nn.Conv2d",
"nn.Conv3d",
"nn.ConvTranspose1d",
"nn.ConvTranspose2d",
"nn.ConvTranspose3d",
"nn.CosineSimilarity",
"nn.PairwiseDistance",
"nn.AlphaDropout",
"nn.Dropout",
"nn.Dropout2d",
"nn.Dropout3d",
"nn.FeatureAlphaDropout",
"nn.Fold",
"nn.Unfold",
"nn.InstanceNorm1d",
"nn.InstanceNorm2d",
"nn.InstanceNorm3d",
"nn.Bilinear",
"nn.Identity",
"nn.Linear",
"nn.BCELoss",
"nn.BCEWithLogitsLoss",
"nn.CTCLoss",
"nn.CosineEmbeddingLoss",
"nn.CrossEntropyLoss",
"nn.HingeEmbeddingLoss",
"nn.KLDivLoss",
"nn.L1Loss",
"nn.MSELoss",
"nn.MarginRankingLoss",
"nn.MultiLabelMarginLoss",
"nn.MultiLabelSoftMarginLoss",
"nn.MultiMarginLoss",
"nn.NLLLoss",
"nn.NLLLoss2d",
"nn.PoissonNLLLoss",
"nn.SmoothL1Loss",
"nn.SoftMarginLoss",
"nn.TripletMarginLoss",
"nn.Module",
"nn.CrossMapLRN2d",
"nn.GroupNorm",
"nn.LayerNorm",
"nn.LocalResponseNorm",
"nn.ConstantPad1d",
"nn.ConstantPad2d",
"nn.ConstantPad3d",
"nn.ReflectionPad1d",
"nn.ReflectionPad2d",
"nn.ReplicationPad1d",
"nn.ReplicationPad2d",
"nn.ReplicationPad3d",
"nn.ZeroPad2d",
"nn.PixelShuffle",
"nn.AdaptiveAvgPool1d",
"nn.AdaptiveAvgPool2d",
"nn.AdaptiveAvgPool3d",
"nn.AdaptiveMaxPool1d",
"nn.AdaptiveMaxPool2d",
"nn.AdaptiveMaxPool3d",
"nn.AvgPool1d",
"nn.AvgPool2d",
"nn.AvgPool3d",
"nn.FractionalMaxPool2d",
"nn.FractionalMaxPool3d",
"nn.LPPool1d",
"nn.LPPool2d",
"nn.MaxPool1d",
"nn.MaxPool2d",
"nn.MaxPool3d",
"nn.MaxUnpool1d",
"nn.MaxUnpool2d",
"nn.MaxUnpool3d",
"nn.GRU",
"nn.GRUCell",
"nn.LSTM",
"nn.LSTMCell",
"nn.RNN",
"nn.RNNBase",
"nn.RNNCell",
"nn.RNNCellBase",
"nn.Embedding",
"nn.EmbeddingBag",
"nn.Upsample",
"nn.UpsamplingBilinear2d",
"nn.UpsamplingNearest2d",
"nn.Transformer",
"nn.TransformerEncoder",
"nn.TransformerDecoder",
"nn.TransformerEncoderLayer",
"nn.TransformerDecoderLayer",
"nn.Parameter",
"nn.DataParallel"
]
\ No newline at end of file
[
".new_tensor",
".new_full",
".new_empty",
".new_ones",
".new_zeros",
".abs",
".abs_",
".acos",
".acos_",
".add",
".add_",
".addbmm",
".addbmm_",
".addcdiv",
".addcdiv_",
".addcmul",
".addcmul_",
".addmm",
".addmm_",
".addmv",
".addmv_",
".addr",
".addr_",
".allclose",
".angle",
".apply_",
".argmax",
".argmin",
".argsort",
".asin",
".asin_",
".as_strided",
".atan",
".atan2",
".atan2_",
".atan_",
".baddbmm",
".baddbmm_",
".bernoulli",
".bernoulli_",
".bfloat16",
".bincount",
".bitwise_not",
".bitwise_not_",
".bitwise_and",
".bitwise_and_",
".bitwise_or",
".bitwise_or_",
".bitwise_xor",
".bitwise_xor_",
".bmm",
".bool",
".byte",
".cauchy_",
".ceil",
".ceil_",
".char",
".cholesky",
".cholesky_inverse",
".cholesky_solve",
".chunk",
".clamp",
".clamp_",
".clone",
".contiguous",
".copy_",
".conj",
".cos",
".cos_",
".cosh",
".cosh_",
".cpu",
".cross",
".cuda",
".cummax",
".cummin",
".cumprod",
".cumsum",
".data_ptr",
".dequantize",
".det",
".dense_dim",
".diag",
".diag_embed",
".diagflat",
".diagonal",
".fill_diagonal_",
".digamma",
".digamma_",
".dim",
".dist",
".div",
".div_",
".dot",
".double",
".eig",
".element_size",
".eq",
".eq_",
".equal",
".erf",
".erf_",
".erfc",
".erfc_",
".erfinv",
".erfinv_",
".exp",
".exp_",
".expm1",
".expm1_",
".expand",
".expand_as",
".exponential_",
".fft",
".fill_",
".flatten",
".flip",
".float",
".floor",
".floor_",
".floor_divide",
".floor_divide_",
".fmod",
".fmod_",
".frac",
".frac_",
".gather",
".ge",
".ge_",
".geometric_",
".geqrf",
".ger",
".get_device",
".gt",
".gt_",
".half",
".hardshrink",
".histc",
".ifft",
".index_add_",
".index_add",
".index_copy_",
".index_copy",
".index_fill_",
".index_fill",
".index_put_",
".index_put",
".index_select",
".indices",
".int",
".int_repr",
".inverse",
".irfft",
".is_contiguous",
".is_complex",
".is_floating_point",
".is_pinned",
".is_set_to",
".is_shared",
".is_signed",
".item",
".kthvalue",
".le",
".le_",
".lerp",
".lerp_",
".lgamma",
".lgamma_",
".log",
".log_",
".logdet",
".log10",
".log10_",
".log1p",
".log1p_",
".log2",
".log2_",
".log_normal_",
".logsumexp",
".logical_and",
".logical_and_",
".logical_not",
".logical_not_",
".logical_or",
".logical_or_",
".logical_xor",
".logical_xor_",
".long",
".lstsq",
".lt",
".lt_",
".lu",
".lu_solve",
".map_",
".masked_scatter_",
".masked_scatter",
".masked_fill_",
".masked_fill",
".masked_select",
".matmul",
".matrix_power",
".max",
".mean",
".median",
".min",
".mm",
".mode",
".mul",
".mul_",
".multinomial",
".mv",
".mvlgamma",
".mvlgamma_",
".narrow",
".narrow_copy",
".ndimension",
".ne",
".ne_",
".neg",
".neg_",
".nelement",
".nonzero",
".norm",
".normal_",
".numel",
".numpy",
".orgqr",
".ormqr",
".permute",
".pin_memory",
".pinverse",
".polygamma",
".polygamma_",
".pow",
".pow_",
".prod",
".put_",
".qr",
".qscheme",
".q_scale",
".q_zero_point",
".q_per_channel_scales",
".q_per_channel_zero_points",
".q_per_channel_axis",
".random_",
".reciprocal",
".reciprocal_",
".record_stream",
".remainder",
".remainder_",
".renorm",
".renorm_",
".repeat",
".repeat_interleave",
".requires_grad_",
".reshape",
".reshape_as",
".resize_",
".resize_as_",
".rfft",
".roll",
".rot90",
".round",
".round_",
".rsqrt",
".rsqrt_",
".scatter",
".scatter_",
".scatter_add_",
".scatter_add",
".select",
".set_",
".share_memory_",
".short",
".sigmoid",
".sigmoid_",
".sign",
".sign_",
".sin",
".sin_",
".sinh",
".sinh_",
".size",
".slogdet",
".solve",
".sort",
".split",
".sparse_mask",
".sparse_dim",
".sqrt",
".sqrt_",
".square",
".square_",
".squeeze",
".squeeze_",
".std",
".stft",
".storage",
".storage_offset",
".storage_type",
".stride",
".sub",
".sub_",
".sum",
".sum_to_size",
".svd",
".symeig",
".t",
".t_",
".to",
".to_mkldnn",
".take",
".tan",
".tan_",
".tanh",
".tanh_",
".tolist",
".topk",
".to_sparse",
".trace",
".transpose",
".transpose_",
".triangular_solve",
".tril",
".tril_",
".triu",
".triu_",
".true_divide",
".true_divide_",
".trunc",
".trunc_",
".type",
".type_as",
".unbind",
".unfold",
".uniform_",
".unique",
".unique_consecutive",
".unsqueeze",
".unsqueeze_",
".values",
".var",
".view",
".view_as",
".where",
".zero_"
]
\ No newline at end of file
[
"torch.is_tensor",
"torch.is_storage",
"torch.is_complex",
"torch.is_floating_point",
"torch.set_default_dtype",
"torch.get_default_dtype",
"torch.set_default_tensor_type",
"torch.numel",
"torch.set_printoptions",
"torch.set_flush_denormal",
"torch.tensor",
"torch.sparse_coo_tensor",
"torch.as_tensor",
"torch.as_strided",
"torch.from_numpy",
"torch.zeros",
"torch.zeros_like",
"torch.ones",
"torch.ones_like",
"torch.arange",
"torch.range",
"torch.linspace",
"torch.logspace",
"torch.eye",
"torch.empty",
"torch.empty_like",
"torch.empty_strided",
"torch.full",
"torch.full_like",
"torch.quantize_per_tensor",
"torch.quantize_per_channel",
"torch.cat",
"torch.chunk",
"torch.gather",
"torch.index_select",
"torch.masked_select",
"torch.narrow",
"torch.nonzero",
"torch.reshape",
"torch.split",
"torch.squeeze",
"torch.stack",
"torch.t",
"torch.take",
"torch.transpose",
"torch.unbind",
"torch.unsqueeze",
"torch.where",
"torch._C.Generator",
"torch._C.Generator.device",
"torch._C.Generator.get_state",
"torch._C.Generator.initial_seed",
"torch._C.Generator.manual_seed",
"torch._C.Generator.seed",
"torch._C.Generator.set_state",
"torch.seed",
"torch.manual_seed",
"torch.initial_seed",
"torch.get_rng_state",
"torch.set_rng_state",
"torch.default_generator",
"torch.bernoulli",
"torch.multinomial",
"torch.normal",
"torch.poisson",
"torch.rand",
"torch.rand_like",
"torch.randint",
"torch.randint_like",
"torch.randn",
"torch.randn_like",
"torch.randperm",
"torch.quasirandom.SobolEngine",
"torch.quasirandom.SobolEngine.draw",
"torch.quasirandom.SobolEngine.fast_forward",
"torch.quasirandom.SobolEngine.reset",
"torch.save",
"torch.load",
"torch.get_num_threads",
"torch.set_num_threads",
"torch.get_num_interop_threads",
"torch.set_num_interop_threads",
"torch.no_grad",
"torch.enable_grad",
"torch.set_grad_enabled",
"torch.abs",
"torch.acos",
"torch.add",
"torch.addcdiv",
"torch.addcmul",
"torch.angle",
"torch.asin",
"torch.atan",
"torch.atan2",
"torch.bitwise_not",
"torch.bitwise_and",
"torch.bitwise_or",
"torch.bitwise_xor",
"torch.ceil",
"torch.clamp",
"torch.conj",
"torch.cos",
"torch.cosh",
"torch.div",
"torch.digamma",
"torch.erf",
"torch.erfc",
"torch.erfinv",
"torch.exp",
"torch.expm1",
"torch.floor",
"torch.floor_divide",
"torch.fmod",
"torch.frac",
"torch.imag",
"torch.lerp",
"torch.lgamma",
"torch.log",
"torch.log10",
"torch.log1p",
"torch.log2",
"torch.logical_and",
"torch.logical_not",
"torch.logical_or",
"torch.logical_xor",
"torch.mul",
"torch.mvlgamma",
"torch.neg",
"torch.polygamma",
"torch.pow",
"torch.real",
"torch.reciprocal",
"torch.remainder",
"torch.round",
"torch.rsqrt",
"torch.sigmoid",
"torch.sign",
"torch.sin",
"torch.sinh",
"torch.sqrt",
"torch.square",
"torch.tan",
"torch.tanh",
"torch.true_divide",
"torch.trunc",
"torch.argmax",
"torch.argmin",
"torch.dist",
"torch.logsumexp",
"torch.mean",
"torch.median",
"torch.mode",
"torch.norm",
"torch.prod",
"torch.std",
"torch.std_mean",
"torch.sum",
"torch.unique",
"torch.unique_consecutive",
"torch.var",
"torch.var_mean",
"torch.allclose",
"torch.argsort",
"torch.eq",
"torch.equal",
"torch.ge",
"torch.gt",
"torch.isfinite",
"torch.isinf",
"torch.isnan",
"torch.kthvalue",
"torch.le",
"torch.lt",
"torch.max",
"torch.min",
"torch.ne",
"torch.sort",
"torch.topk",
"torch.fft",
"torch.ifft",
"torch.rfft",
"torch.irfft",
"torch.stft",
"torch.bartlett_window",
"torch.blackman_window",
"torch.hamming_window",
"torch.hann_window",
"torch.bincount",
"torch.broadcast_tensors",
"torch.cartesian_prod",
"torch.cdist",
"torch.combinations",
"torch.cross",
"torch.cummax",
"torch.cummin",
"torch.cumprod",
"torch.cumsum",
"torch.diag",
"torch.diag_embed",
"torch.diagflat",
"torch.diagonal",
"torch.einsum",
"torch.flatten",
"torch.flip",
"torch.rot90",
"torch.histc",
"torch.meshgrid",
"torch.renorm",
"torch.repeat_interleave",
"torch.roll",
"torch.tensordot",
"torch.trace",
"torch.tril",
"torch.tril_indices",
"torch.triu",
"torch.triu_indices",
"torch.addbmm",
"torch.addmm",
"torch.addmv",
"torch.addr",
"torch.baddbmm",
"torch.bmm",
"torch.chain_matmul",
"torch.cholesky",
"torch.cholesky_inverse",
"torch.cholesky_solve",
"torch.dot",
"torch.eig",
"torch.geqrf",
"torch.ger",
"torch.inverse",
"torch.det",
"torch.logdet",
"torch.slogdet",
"torch.lstsq",
"torch.lu",
"torch.lu_solve",
"torch.lu_unpack",
"torch.matmul",
"torch.matrix_power",
"torch.matrix_rank",
"torch.mm",
"torch.mv",
"torch.orgqr",
"torch.ormqr",
"torch.pinverse",
"torch.qr",
"torch.solve",
"torch.svd",
"torch.svd_lowrank",
"torch.pca_lowrank",
"torch.symeig",
"torch.lobpcg",
"torch.trapz",
"torch.triangular_solve",
"torch.compiled_with_cxx11_abi",
"torch.result_type",
"torch.can_cast",
"torch.promote_types"
]
\ No newline at end of file
......@@ -2,6 +2,7 @@ attrdict>=2.0.1
Click>=7.0
Flask>=1.1.1
Flask-Cors>=3.0.8
google-pasta>=0.1.8
gunicorn>=19.9.0
itsdangerous>=1.1.0
Jinja2>=2.10.1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册