提交 8687a4f7 编写于 作者: L lilongfei 提交者: quyongxiu

add torch2ms and delete dynamic

fix pylint

add comment for all func and parames has instruction

logging use logger

fix pylint

logger Upper case

copyright 2020

use realpath

one example in converter fdopen onefile

use google-pasta instead of astunparse and add in requirements_txt

use __base__.__name__ to judge if subclass of nn.Module

comment fix

fix review problem

fix pylint
上级 8301a7fb
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Create a logger."""
from mindinsight.utils.log import setup_logger
logger = setup_logger("mindconverter", "mindconverter")
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless REQUIRED by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""API config"""
import ast
from collections import OrderedDict
from functools import partial
import json
import os
import pasta
from mindinsight.mindconverter.enums import RequriedType
from mindinsight.mindconverter.common.log import logger
REQUIRED = RequriedType.REQUIRED.name
UNREQUIRED = RequriedType.UNREQUIRED.name
class APIPt:
"""Base API for args parse, and API for one frame."""
def __init__(self, name: str, params: OrderedDict):
self.name = name
self.params = OrderedDict()
for k, value in params.items():
self.params[k] = self.to_str(value)
@staticmethod
def to_str(value):
"""
Trans value to str.
Args:
value (Union[str,Number,int]): Each value for params of OrderedDict.
Returns:
str, str type of value.
"""
if value is REQUIRED:
return value
if isinstance(value, str):
return "'{}'".format(value)
return str(value)
def parse_args(self, call_name: str, args_str: str):
"""
Parse call_name and args_str.
Args:
call_name (str): str of the call function, etc.
args_str (str): str of args for function, which starts with '(' and end with ')'.
Returns:
OrderedDict, all args parsed.
Raises:
ValueError: If can not use ast to parse or the required parse node not type of ast.Call,
or the given args_str not valid.
"""
# expr is REQUIRED to meet (**) format
if not (len(args_str) >= 2 and args_str[0] == "(" and args_str[-1] == ")"):
raise ValueError('[{}] is think as args str, it should start with "(" and end with ")"'.format(args_str))
try:
ast_node = ast.parse("whatever_call_name" + args_str)
call_node = ast_node.body[0].value
if not isinstance(call_node, ast.Call):
raise ValueError('call name with args str [{}] not instance of ast.Call'.format(args_str))
except:
raise ValueError("can't parse code:\n{}".format(args_str))
# regard all actual parameter as one parameter
if len(self.params) == 1:
k = list(self.params.keys())[0]
if k.startswith('*'):
value = args_str[1:-1]
return OrderedDict([(k, value), ("call_name", call_name)])
args = OrderedDict()
# param which name not assigned
param_iter = iter(self.params.keys())
if len(call_node.args) > len(self.params):
raise ValueError('Parse args of torch in {}, but there is problems with params'.format(call_name))
for arg in call_node.args:
if isinstance(arg, ast.Starred):
logger.debug("Find *%s", arg.value.id)
args['*'] = arg.value.id
else:
# remove \n
args[next(param_iter)] = pasta.dump(arg).strip()
# params which name is assigned
for keyword in call_node.keywords:
if keyword.arg is None:
logger.info("Find **%s", keyword.value.id)
args['**'] = keyword.value.id
else:
# remove \n
args[keyword.arg] = pasta.dump(keyword.value).strip()
args["call_name"] = call_name
return args
class APIMs(APIPt):
"""API for MindSpore"""
def __init__(self, name: str, params: OrderedDict, p_attrs=None):
self.is_primitive = name.startswith('P.')
if self.is_primitive:
self.p_attrs = p_attrs if p_attrs else set()
super(APIMs, self).__init__(name, params)
def create_args(self, params_pt, args_pt, ms2pt_map, explicit_map):
"""
Create args for MindSpore according to other frame op info.
Args:
params_pt (OrderedDict): Params used for initialize function of APIPt.
args_pt (OrderedDict): Args parsed from APIPt.
ms2pt_map (dict): Dict of params mapping relation for ops between frames.
explicit_map(func): Function to generate mapping relation for ops between frames.
Returns:
OrderedDict, args for MindSpore.
"""
args = OrderedDict()
# traverse MindSpore's params
for k in self.params.keys():
# has relevant param? yes
if k in ms2pt_map:
if ms2pt_map[k] in args_pt:
# user assigned value
args[k] = args_pt[ms2pt_map[k]]
elif self.params[k] != params_pt[ms2pt_map[k]]:
# user didn't assigned value, but initial value different between 2 frames
args[k] = params_pt[ms2pt_map[k]]
# has relevant param? no
else:
# params forced to display
if k in explicit_map:
args[k] = explicit_map[k]
elif self.params[k] is REQUIRED:
args[k] = "<REQUIRED>"
# find * or ** in frame actual parameters
for star in ('*', '**'):
if star in args_pt:
args[star] = args_pt[star]
return args
class MappingHelper:
"""Mapping from one frame to another frame"""
def __init__(self, ms_api: APIMs, pt_api: APIPt, **kwargs):
ms2pt_mapping = kwargs.get('ms2pt_mapping')
gen_explicit_map = kwargs.get('gen_explicit_map')
export_key = kwargs.get('export_key')
if ms2pt_mapping is None:
ms2pt_mapping = {}
if gen_explicit_map is None:
gen_explicit_map = lambda params_pt, args_pt: {}
self.ms_api = ms_api
self.pt_api = pt_api
self.ms2pt_mapping = ms2pt_mapping
self.gen_explicit_map = gen_explicit_map
if export_key is not None:
self.export_key = export_key
else:
self.export_key = not ms_api.is_primitive
def gen_args_expr(self, args):
"""
Generate str assignment statement from given dict.
Args:
args (OrderedDict): Key, value pairs for assignment source.
Returns:
str, generated str.
"""
expr = ''
for key, value in args.items():
if expr:
expr += ', '
sym = '' if key in ('*', '**') else '='
if self.export_key:
expr += key + sym
expr += value
return expr
def gen_args_expr_for_p(self, args, p_attrs):
"""
Generate str assignment statement from given dict for primitive and not primitive.
Args:
args (OrderedDict): Key, value pairs for assignment source.
p_attrs (set): Exclusive params for operator.
Returns:
tuple, generated str for primitive, generated str for not primitive.
"""
args_attrs = OrderedDict([(k, v) for k, v in args.items() if k in p_attrs])
args_ios = OrderedDict([(k, v) for k, v in args.items() if k not in p_attrs])
return self.gen_args_expr(args_attrs), self.gen_args_expr(args_ios)
def convert(self, call_name_pt: str, args_str_pt: str):
"""
Convert code sentence to MindSpore code sentence.
Args:
call_name_pt (str): str of the call function, etc.
args_str_pt (str): str of args for function, which starts with '(' and end with ')'.
Returns:
str, converted code sentence for MindSpore.
"""
# all value for args_pt is str
args_pt = self.pt_api.parse_args(call_name_pt, args_str_pt)
# all value for args_ms is str
explicit_map = self.gen_explicit_map(self.pt_api.params, args_pt)
args_ms = self.ms_api.create_args(self.pt_api.params, args_pt, self.ms2pt_mapping, explicit_map)
if self.ms_api.is_primitive:
if self.pt_api.name == '.size' and 'idx' in args_pt:
args_expr = self.gen_args_expr(args_ms)
expr_ms = "%s()(%s)[%s]" % (self.ms_api.name, args_expr, args_pt['idx'])
else:
expr_attrs, expr_ios = self.gen_args_expr_for_p(args_ms, self.ms_api.p_attrs)
expr_ms = "%s(%s)(%s)" % (self.ms_api.name, expr_attrs, expr_ios)
else:
ms_expr = self.gen_args_expr(args_ms)
expr_ms = "%s(%s)" % (self.ms_api.name, ms_expr)
return expr_ms
def gen_explicit_map_nn_sequential(_, args_pt):
"""
Generate explicit_map for nn.Sequential.
Args:
args_pt (dict): Args for APIPt.
Returns:
dict, map between frames.
"""
args = args_pt['*args']
return {"*args": "[{}]".format(args)}
def gen_explicit_map_nn_maxpool2d(params_pt, args_pt):
"""
Generate explicit_map for nn.MaxPool2d.
Args:
params_pt (dict): Params for APIPt.
args_pt (dict): Args for APIPt.
Returns:
dict, map between frames.
"""
if 'padding' in args_pt:
padding = args_pt['padding']
else:
padding = params_pt['padding']
if padding.strip() in ("0", "(0,0)", "(0, 0)"):
pad_mode = "'valid'"
else:
pad_mode = "'same'"
return {"pad_mode": pad_mode}
def gen_explicit_map_f_max_pool2d(params_pt, args_pt):
"""
Generate explicit_map for F.MaxPool2d.
Args:
params_pt (dict): Params for APIPt.
args_pt (dict): Args for APIPt.
Returns:
dict, map between frames.
"""
if 'padding' in args_pt:
padding = args_pt['padding']
else:
padding = params_pt['padding']
if padding.strip() in ("0", "(0,0)", "(0, 0)"):
padding = "'valid'"
else:
padding = "'same'"
return {"padding": padding}
def gen_explicit_map_one_delta(params_pt, args_pt, k_ms, k_pt):
"""
Generate explicit_map for which include mapping relationship is `1 - k_ms = k_pt`.
Args:
params_pt (dict): Params for APIPt.
args_pt (dict): Args for APIPt.
Returns:
dict, map between frames.
"""
value = args_pt[k_pt] if k_pt in args_pt else params_pt[k_pt]
value = value.strip()
def is_number(string):
try:
float(string)
return True
except ValueError:
return False
if is_number(value):
return {k_ms: str(1 - float(value))}
return {k_ms: "1.0 - " + value}
def load_json_file(file_path):
"""
Load data from given json file path.
Args:
file_path (str): The file to load json data from.
Returns:
list, the list data stored in file_path.
"""
with open(file_path, 'r', encoding='utf-8') as file:
info = json.loads(file.read())
return info
# ---------------------------- mappings ----------------------------
NN_MAPPING = {
'nn.Sequential': MappingHelper(**{"ms_api": APIMs('nn.SequentialCell', OrderedDict([('*args', REQUIRED)])),
"pt_api": APIPt('nn.Sequential', OrderedDict([('*args', REQUIRED)])),
"gen_explicit_map": gen_explicit_map_nn_sequential,
"export_key": False
}),
'nn.Conv2d': MappingHelper(**{"ms_api": APIMs('nn.Conv2d', OrderedDict(in_channels=REQUIRED,
out_channels=REQUIRED,
kernel_size=REQUIRED,
stride=1,
pad_mode='same',
padding=0,
dilation=1,
group=1,
has_bias=False,
weight_init='normal',
bias_init='zeros')),
"pt_api": APIPt('nn.Conv2d', OrderedDict(in_channels=REQUIRED,
out_channels=REQUIRED,
kernel_size=REQUIRED,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode='zeros')),
"ms2pt_mapping": {'in_channels': 'in_channels',
'out_channels': 'out_channels',
'kernel_size': 'kernel_size',
'stride': 'stride',
'padding': 'padding',
'dilation': 'dilation',
'group': 'groups',
'has_bias': 'bias'},
"gen_explicit_map": (lambda params_pt, args_pt: {"pad_mode": "'pad'"})
}),
'nn.BatchNorm2d': MappingHelper(**{"ms_api": APIMs('nn.BatchNorm2d', OrderedDict(num_features=REQUIRED,
eps=1e-5,
momentum=0.9,
affine=True,
gamma_init='ones',
beta_init='zeros',
moving_mean_init='zeros',
moving_var_init='ones',
use_batch_statistics=True)),
"pt_api": APIPt('nn.BatchNorm2d', OrderedDict(num_features=REQUIRED,
eps=1e-5,
momentum=0.1,
affine=True,
track_running_stats=True)),
"ms2pt_mapping": {"num_features": "num_features",
"eps": "eps",
"affine": "affine",
"use_batch_statistics": "track_running_stats"},
"gen_explicit_map": partial(gen_explicit_map_one_delta,
k_ms="momentum", k_pt="momentum")
}),
'nn.ReLU': MappingHelper(**{"ms_api": APIMs('nn.ReLU', OrderedDict()),
"pt_api": APIPt('nn.ReLU', OrderedDict(inplace=False)),
"ms2pt_mapping": {}}),
'nn.ReLU6': MappingHelper(**{"ms_api": APIMs('nn.ReLU6', OrderedDict()),
"pt_api": APIPt('nn.ReLU6', OrderedDict(inplace=False)),
"ms2pt_mapping": {}}),
'nn.Linear': MappingHelper(**{"ms_api": APIMs('nn.Dense', OrderedDict(in_channels=REQUIRED,
out_channels=REQUIRED,
weight_init='normal',
bias_init='zeros',
has_bias=True,
activation=None)),
"pt_api": APIPt('nn.Linear', OrderedDict(in_features=REQUIRED,
out_features=REQUIRED,
bias=True)),
"ms2pt_mapping": {"in_channels": "in_features",
"out_channels": "out_features",
"has_bias": "bias"}
}),
'nn.MaxPool2d': MappingHelper(**{"ms_api": APIMs('nn.MaxPool2d', OrderedDict(kernel_size=1,
stride=1,
pad_mode="valid")),
"pt_api": APIPt('nn.MaxPool2d', OrderedDict(kernel_size=REQUIRED,
stride=None,
padding=0,
dilation=1,
return_indices=False,
ceil_mode="False")),
"ms2pt_mapping": {"kernel_size": "kernel_size",
"stride": "stride"},
"gen_explicit_map": gen_explicit_map_nn_maxpool2d
}),
'nn.AvgPool2d': MappingHelper(**{"ms_api": APIMs('nn.AvgPool2d', OrderedDict(kernel_size=1,
stride=1,
pad_mode="valid")),
"pt_api": APIPt('nn.AvgPool2d', OrderedDict(kernel_size=REQUIRED,
stride=None,
padding=0,
dilation=1,
return_indices=False,
ceil_mode="False")),
"ms2pt_mapping": {"kernel_size": "kernel_size",
"stride": "stride"},
"gen_explicit_map": gen_explicit_map_nn_maxpool2d
}),
'nn.Dropout': MappingHelper(**{"ms_api": APIMs('nn.Dropout', OrderedDict(keep_prob=0.5,
seed0=0,
seed1=0,
dtype="mstype.float32")),
"pt_api": APIPt('nn.Dropout', OrderedDict(p=0.5,
inplace=False)),
"ms2pt_mapping": {"keep_prob": "p"},
"gen_explicit_map": partial(gen_explicit_map_one_delta,
k_ms="keep_prob", k_pt="p")
})
}
# set alias nn. = torch.nn.
NN_MAPPING.update({"torch." + k: v for k, v in NN_MAPPING.items()})
F_MAPPING = {
'F.relu': MappingHelper(**{"ms_api": APIMs('P.ReLU', OrderedDict(input=REQUIRED)),
"pt_api": APIPt('F.relu', OrderedDict(input=REQUIRED, inplace=False)),
"ms2pt_mapping": {"input": "input"},
}),
'F.relu6': MappingHelper(**{"ms_api": APIMs('P.ReLU6', OrderedDict(input=REQUIRED)),
"pt_api": APIPt('F.relu6', OrderedDict(input=REQUIRED, inplace=False)),
"ms2pt_mapping": {"input": "input"},
}),
'F.max_pool2d': MappingHelper(**{"ms_api": APIMs('P.MaxPool', OrderedDict(ksize=1,
strides=1,
padding="valid",
input=REQUIRED),
p_attrs={"ksize", "strides", "padding"}),
"pt_api": APIPt('F.max_pool2d', OrderedDict(input=REQUIRED,
kernel_size=REQUIRED,
stride=None,
padding=0,
dilation=1,
ceil_mode=False,
return_indices=False)),
"ms2pt_mapping": {"ksize": "kernel_size",
"strides": "stride",
"input": "input",
},
"gen_explicit_map": gen_explicit_map_f_max_pool2d
}),
'F.avg_pool2d': MappingHelper(**{"ms_api": APIMs('P.AvgPool', OrderedDict(ksize=1,
strides=1,
padding="valid",
input=REQUIRED),
p_attrs={"ksize", "strides", "padding"}),
"pt_api": APIPt('F.avg_pool2d', OrderedDict(input=REQUIRED,
kernel_size=REQUIRED,
stride=None,
padding=0,
dilation=1,
ceil_mode=False,
return_indices=False)),
"ms2pt_mapping": {"ksize": "kernel_size",
"strides": "stride",
"input": "input",
},
"gen_explicit_map": gen_explicit_map_f_max_pool2d
}),
}
# set alias F = nn.functional = torch.nn.functional
nn_functional_d = {"nn.functional." + k[2:]: v for k, v in F_MAPPING.items()}
torch_nn_functional_d = {"torch.nn.functional." + k[2:]: v for k, v in F_MAPPING.items()}
F_MAPPING.update(nn_functional_d)
F_MAPPING.update(torch_nn_functional_d)
TORCH_DOT_MAPPING = {
'torch.flatten': MappingHelper(**{"ms_api": APIMs('P.Flatten', OrderedDict(input=REQUIRED)),
"pt_api": APIPt('torch.flatten', OrderedDict(input=REQUIRED,
start_dim=0,
end_dim=-1)),
"ms2pt_mapping": {"input": "input"}
}),
'torch.cat': MappingHelper(**{"ms_api": APIMs('P.Concat',
OrderedDict(axis=0, input=REQUIRED),
p_attrs={"axis"}),
"pt_api": APIPt('torch.flatten', OrderedDict(tensors=REQUIRED, dim=0, out=None)),
"ms2pt_mapping": {"input": "tensors",
"axis": "dim"}
}),
}
TENSOR_DOT_MAPPING = {
'.view': MappingHelper(**{"ms_api": APIMs('P.Reshape', OrderedDict(x=REQUIRED, shape=REQUIRED)),
"pt_api": APIPt('.view', OrderedDict([('*shape', REQUIRED)])),
"ms2pt_mapping": {"x": "call_name"},
"gen_explicit_map": (lambda params_pt, args_pt: {"shape": "(" + args_pt["*shape"] + ",)"})
}),
'.size': MappingHelper(**{"ms_api": APIMs('P.Shape', OrderedDict(x=REQUIRED)),
"pt_api": APIPt('.size', OrderedDict([('idx', REQUIRED)])),
"ms2pt_mapping": {"x": "call_name"}
}),
'.flatten': MappingHelper(**{"ms_api": APIMs('P.Flatten', OrderedDict(input=REQUIRED)),
"pt_api": APIPt('.flatten', OrderedDict(start_dim=0,
end_dim=-1)),
"ms2pt_mapping": {"input": "call_name"}
}),
'.reshape': MappingHelper(**{"ms_api": APIMs('P.Reshape', OrderedDict(x=REQUIRED, shape=REQUIRED)),
"pt_api": APIPt('.reshape', OrderedDict([('*shape', REQUIRED)])),
"ms2pt_mapping": {"x": "call_name"},
"gen_explicit_map": (
lambda params_pt, args_pt: {"shape": "(" + args_pt["*shape"] + ",)"})
}),
'.mean': MappingHelper(**{"ms_api": APIMs('P.ReduceMean', OrderedDict(keep_dims=False,
input=REQUIRED,
axis=()),
p_attrs={"keep_dims"}),
"pt_api": APIPt('.mean', OrderedDict(dim=None,
keepdim=False)),
"ms2pt_mapping": {"keep_dims": "keepdim",
"axis": "dim",
"input": "call_name"},
}),
'.squeeze': MappingHelper(**{"ms_api": APIMs('P.ReduceMean', OrderedDict(input=REQUIRED,
axis=()),
p_attrs={"axis"}),
"pt_api": APIPt('.squeeze', OrderedDict(dim=None)),
"ms2pt_mapping": {"axis": "dim",
"input": "call_name"},
}),
}
ALL_MAPPING = {**NN_MAPPING, **F_MAPPING, **TORCH_DOT_MAPPING, **TENSOR_DOT_MAPPING}
# ---------------------------- api list support or not support ----------------------------
NN_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'nn_list.json'))
NN_LIST = load_json_file(NN_LIST_PATH)
# set alias nn. = torch.nn.
NN_LIST += ["torch." + name for name in NN_LIST]
NN_SUPPORTED = [x for x in NN_LIST if x in ALL_MAPPING]
NN_UNSUPPORTED = [x for x in NN_LIST if x not in ALL_MAPPING]
F_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'f_list.json'))
F_LIST = load_json_file(F_LIST_PATH)
# set alias F = nn.functional = torch.nn.functional
F_LIST += ["F." + name[len("torch.nn.functional."):] for name in F_LIST] + \
[name[len("torch."):] for name in F_LIST]
F_SUPPORTED = [x for x in F_LIST if x in ALL_MAPPING]
F_UNSUPPORTED = [x for x in F_LIST if x not in ALL_MAPPING]
TORCH_DOT_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'torch_dot_list.json'))
TORCH_DOT_LIST = load_json_file(TORCH_DOT_LIST_PATH)
TORCH_DOT_SUPPORTED = [x for x in TORCH_DOT_LIST if x in ALL_MAPPING]
TORCH_DOT_UNSUPPORTED = [x for x in TORCH_DOT_LIST if x not in ALL_MAPPING]
TENSOR_DOT_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'tensor_dot_list.json'))
TENSOR_DOT_LIST = load_json_file(TENSOR_DOT_LIST_PATH)
TENSOR_DOT_SUPPORTED = [x for x in TENSOR_DOT_LIST if x in ALL_MAPPING]
TENSOR_DOT_UNSUPPORTED = [x for x in TENSOR_DOT_LIST if x not in ALL_MAPPING]
ALL_2P_LIST = F_LIST + TORCH_DOT_LIST + TENSOR_DOT_LIST
ALL_TORCH_APIS = NN_LIST + F_LIST + TORCH_DOT_LIST + TENSOR_DOT_LIST
ALL_SUPPORTED = NN_SUPPORTED + F_SUPPORTED + TORCH_DOT_SUPPORTED + TENSOR_DOT_SUPPORTED
ALL_UNSUPPORTED = NN_UNSUPPORTED + F_UNSUPPORTED + TORCH_DOT_UNSUPPORTED + TENSOR_DOT_UNSUPPORTED
UNSUPPORTED_WARN_INFOS = {
"nn.AdaptiveAvgPool2d": "maybe could convert to P.ReduceMean",
"F.adaptive_avg_pool2d": "maybe could convert to P.ReduceMean",
"F.dropout": "please use nn.Dropout in __init__()",
}
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless REQUIRED by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""main module"""
import inspect
import copy
import importlib
import os
import stat
from mindinsight.mindconverter.config import ALL_MAPPING
from mindinsight.mindconverter.config import NN_LIST
from mindinsight.mindconverter.config import ALL_TORCH_APIS
from mindinsight.mindconverter.config import ALL_2P_LIST
from mindinsight.mindconverter.config import UNSUPPORTED_WARN_INFOS
from mindinsight.mindconverter.common.log import logger
def is_local_defined(obj, member):
"""
Check if obj and member are both defined in the same source file.
Args:
obj (Union[object, module]): A module or a class.
member (func): A function of obj.
Returns:
bool, True or False.
"""
srcfile = inspect.getsourcefile(obj)
return inspect.getsourcefile(member) == srcfile
def is_valid_module(obj, member):
"""
Check if obj and member defined in same source file and member is inherited from torch.nn.Module.
Args:
obj (Union[object, module]): A module or a class.
member (func): A function.
Returns:
bool, True or False.
"""
return inspect.isclass(member) and (member.__base__.__name__ == 'Module') and is_local_defined(obj, member)
def is_valid_function(obj, member):
"""
Check if member is function and defined in the file same as obj.
Args:
obj (Union[object, module]: The obj.
member (func): The func.
Returns:
bool, True or False.
"""
return inspect.isfunction(member) and is_local_defined(obj, member)
def find_left_parentheses(string, right):
"""
Find index of the first left parenthesis.
Args:
string (str): A line of code.
right (int): Max index of string, same as `len(string) -1`.
Returns:
int, index of the first parenthesis.
Raises:
ValueError: If line of code doesn't contain any pair of `()` or `(` and `)` are not paired.
"""
if string[right] != ')':
raise ValueError('code [{}] at index {} not ")".'.format(string, right))
stack = []
for i in range(right, -1, -1):
if string[i] == ')':
stack.append(')')
elif string[i] == '(':
stack.pop()
if not stack:
return i
raise ValueError("{} should contain ()".format(string))
def find_right_parentheses(string, left):
"""
Find first index of right parenthesis which make all left parenthesis make sense.
Args:
string (str): A line of code.
left (int): Start index of string to find from.
Returns:
int, index of the found right parenthesis.
Raises:
ValueError: If line of code doesn't contain any pair of `()` or `(` and `)` are not paired.
"""
stack = []
for i in range(left, len(string)):
if string[i] == '(':
stack.append('(')
elif string[i] == ')':
stack.pop()
if not stack:
return i
raise ValueError("{} should contain ()".format(string))
def get_call_name(code, end):
"""
Traverse code in a reversed function from index end and get the call name and start index of the call name, if call
name not found, return a null character string and -1
Args:
code (str): The str of code to find from.
end (int): Start index to find.
Returns:
str, founded api name if found, else a null character string.
int, start index of founded api name, -1 if api name not found
"""
stack = []
for i in range(end - 1, -1, -1):
if code[i] in ["(", "[", "{"]:
if stack:
stack.pop()
else:
return code[i + 1:end], i + 1
elif code[i] in [")", "]", "}"]:
stack.append(code[i])
elif stack:
continue
elif not (code[i].isalpha() or code[i].isdigit() or code[i] == '_' or code[i] == '.'):
return code[i + 1:end], i + 1
return "", -1
def convert_api(code, start, api_name=""):
"""
Convert api_name in code to MindSpore api with start as a start index, if api_name is a python api, code will not
convert.
Args:
code (str): The str code to convert.
start (int): The index of code to start convert from.
api_name (str): The api name to convert.
Returns:
str, the converted code.
int, index of converted api_name in code.
"""
# handle format like .shape(
if api_name.startswith('.'):
call_name, new_start = get_call_name(code, start)
if start == -1 or call_name == "self":
return code, start + 1
else:
call_name = api_name
new_start = start
# find full api expected to be converted. eg:expr="nn.Conv2d(1,2,3)" args_str="(1,2,3)"
left = code.find("(", start)
if left == -1:
raise ValueError('"(" not found, {} should work with "("'.format(call_name))
right = find_right_parentheses(code, left)
end = right
expr = code[start:end + 1]
args_str = code[left:right + 1]
map_helper = ALL_MAPPING[api_name]
new_expr = map_helper.convert(call_name, args_str)
next_newline = code.find("\n", end + 1)
fill_num = (expr.count("\n") - new_expr.count("\n"))
if next_newline != -1:
code = code[:new_start] + new_expr + code[end + 1:next_newline] + ("\n" * fill_num) + code[next_newline:]
else:
code = code[:new_start] + new_expr + ")" + ("\n" * fill_num) + code[end + 2:]
return code, start + len(map_helper.ms_api.name)
def find_api(code, i, is_forward):
"""
Find api name from code with a start index i, check api name ok with a is_forward condition.
Args:
code (str): The code from which to find api name.
i (int): The start index to find.
is_forward (bool): Check if the found api name ok.
Returns:
str, api name if find api name and check ok with is_forward condition, else a null character string.
"""
if code[i:].startswith("nn.") \
or code[i:].startswith("F.") \
or code[i:].startswith("torch.") \
or code[i:].startswith('.'):
j = code.find('(', i)
if j != -1 and code[i:j] in ALL_TORCH_APIS:
api_name = code[i:j]
if (not is_forward and api_name in NN_LIST) or (is_forward and api_name in ALL_2P_LIST):
return api_name
return ""
def convert_function(fun_name, fun, is_forward):
"""
Convert a PyTorch function into MindSpore function.
Args:
fun_name (str): The str of function name.
fun (func): The function to convert.
is_forward (bool): If the function is defined in forward function in nn.Module in torch.
Returns:
dict, old code and converted code map if convert happens, else {}.
"""
_, line_no = inspect.getsourcelines(fun)
logger.info("Line %3d: start converting function %s()", line_no, fun_name)
code = inspect.getsource(fun)
code_saved = copy.copy(code)
i = 0
while i < len(code):
api_name = find_api(code, i, is_forward)
if api_name:
line_no1 = line_no + code[:i].count('\n')
if api_name in ALL_MAPPING:
logger.info("Line %3d start converting API: %s", line_no1, api_name)
code, i = convert_api(code, i, api_name)
continue
warn_info = ". " + UNSUPPORTED_WARN_INFOS[api_name] if api_name in UNSUPPORTED_WARN_INFOS else ""
logger.warning("Line %3d: found unsupported API: %s%s", line_no1, api_name, warn_info)
i += 1
return {code_saved: code} if code_saved != code else {}
def judge_forward(name, forward_list):
"""
Check if function is a forward function.
Args:
name (str): The function name.
forward_list (set): A set of forward function.
Returns:
bool, True or False
"""
is_forward = name in forward_list or name.split(".")[-1] == "forward"
if is_forward:
logger.debug("%s is a forward function", name)
return is_forward
def convert_module(module_name, module, forward_list):
"""
Convert a PyTorch module code into MindSpore module code.
Args:
module_name (str): The module's name.
module (module): The module to convert.
forward_list (set): A set of forward function.
Returns:
dict, map of old code and converted code.
"""
_, line_no = inspect.getsourcelines(module)
logger.info("Line {:3d}: start converting nn.Module {}".format(line_no, module_name))
mapped = {}
for name, member in inspect.getmembers(module):
if is_valid_function(module, member):
is_forward = judge_forward("{}.{}".format(module_name, name), forward_list)
mapped.update(convert_function(name, member, is_forward))
return mapped
def get_mapping(import_mod, forward_list):
"""
Convert code of a module and get mapping of old code and convert code.
Args:
import_mod (module): The module to convert.
forward_list (set): A set of forward function.
Returns:
dict, mapping for old code and converted code of the module
"""
mapping = {}
tasks = []
for name, member in inspect.getmembers(import_mod):
if is_valid_module(import_mod, member):
_, line_no = inspect.getsourcelines(member)
tasks.append((line_no, convert_module, (name, member, forward_list)))
elif is_valid_function(import_mod, member):
_, line_no = inspect.getsourcelines(member)
is_forward = judge_forward("{}.{}".format(import_mod, name), forward_list)
tasks.append((line_no, convert_function, (name, member, is_forward)))
tasks.sort()
for _, convert_fun, args in tasks:
mapping.update(convert_fun(*args))
return mapping
def convert(import_name, nn_module):
"""
The entrance for convert a module's code, code converted will be write to file called out.py.
Args:
import_name (str): The module from which to import the module to convert.
nn_module (str): Name of the module to convert.
"""
logger.info("Start converting %s.%s", import_name, nn_module)
import_mod = importlib.import_module(import_name)
forward_list = set()
logger.debug("Forward_list: %s", forward_list)
# replace python function under nn.Modlue
mapping = get_mapping(import_mod, forward_list)
code = inspect.getsource(import_mod)
for key, value in mapping.items():
code = code.replace(key, value)
code = 'import mindspore.ops.operations as P\n' + code
code = 'import mindspore.nn as nn\n' + code
code = 'import mindspore\n' + code
code = code.replace('import torch', '# import torch')
code = code.replace('from torch', '# from torch')
code = code.replace('(nn.Module):', '(nn.Cell):')
code = code.replace('forward(', 'construct(')
code = code.replace('nn.Linear', 'nn.Dense')
code = code.replace('(nn.Sequential)', '(nn.SequentialCell)')
code = code.replace('nn.init.', 'pass # nn.init.')
flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL
modes = stat.S_IWUSR | stat.S_IRUSR
with os.fdopen(os.open('out.py', flags, modes), 'w') as file:
file.write(code)
logger.info("Convert success. Result is wrote to out.py\n")
if __name__ == '__main__':
convert('torchvision.models.resnet', 'resnet18')
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Enums."""
from enum import Enum
class RequriedType(Enum):
"""If param is required"""
REQUIRED = 1
UNREQUIRED = 2
[
"torch.nn.functional.conv1d",
"torch.nn.functional.conv2d",
"torch.nn.functional.conv3d",
"torch.nn.functional.conv_transpose1d",
"torch.nn.functional.conv_transpose2d",
"torch.nn.functional.conv_transpose3d",
"torch.nn.functional.unfold",
"torch.nn.functional.fold",
"torch.nn.functional.avg_pool1d",
"torch.nn.functional.avg_pool2d",
"torch.nn.functional.avg_pool3d",
"torch.nn.functional.max_pool1d",
"torch.nn.functional.max_pool2d",
"torch.nn.functional.max_pool3d",
"torch.nn.functional.max_unpool1d",
"torch.nn.functional.max_unpool2d",
"torch.nn.functional.max_unpool3d",
"torch.nn.functional.lp_pool1d",
"torch.nn.functional.lp_pool2d",
"torch.nn.functional.adaptive_max_pool1d",
"torch.nn.functional.adaptive_max_pool2d",
"torch.nn.functional.adaptive_max_pool3d",
"torch.nn.functional.adaptive_avg_pool1d",
"torch.nn.functional.adaptive_avg_pool2d",
"torch.nn.functional.adaptive_avg_pool3d",
"torch.nn.functional.threshold",
"torch.nn.functional.threshold_",
"torch.nn.functional.relu",
"torch.nn.functional.relu_",
"torch.nn.functional.hardtanh",
"torch.nn.functional.hardtanh_",
"torch.nn.functional.relu6",
"torch.nn.functional.elu",
"torch.nn.functional.elu_",
"torch.nn.functional.selu",
"torch.nn.functional.celu",
"torch.nn.functional.leaky_relu",
"torch.nn.functional.leaky_relu_",
"torch.nn.functional.prelu",
"torch.nn.functional.rrelu",
"torch.nn.functional.rrelu_",
"torch.nn.functional.glu",
"torch.nn.functional.gelu",
"torch.nn.functional.logsigmoid",
"torch.nn.functional.hardshrink",
"torch.nn.functional.tanhshrink",
"torch.nn.functional.softsign",
"torch.nn.functional.softplus",
"torch.nn.functional.softmin",
"torch.nn.functional.softmax",
"torch.nn.functional.softshrink",
"torch.nn.functional.gumbel_softmax",
"torch.nn.functional.log_softmax",
"torch.nn.functional.tanh",
"torch.nn.functional.sigmoid",
"torch.nn.functional.batch_norm",
"torch.nn.functional.instance_norm",
"torch.nn.functional.layer_norm",
"torch.nn.functional.local_response_norm",
"torch.nn.functional.normalize",
"torch.nn.functional.linear",
"torch.nn.functional.bilinear",
"torch.nn.functional.dropout",
"torch.nn.functional.alpha_dropout",
"torch.nn.functional.dropout2d",
"torch.nn.functional.dropout3d",
"torch.nn.functional.embedding",
"torch.nn.functional.embedding_bag",
"torch.nn.functional.one_hot",
"torch.nn.functional.pairwise_distance",
"torch.nn.functional.cosine_similarity",
"torch.nn.functional.pdist",
"torch.nn.functional.binary_cross_entropy",
"torch.nn.functional.binary_cross_entropy_with_logits",
"torch.nn.functional.poisson_nll_loss",
"torch.nn.functional.cosine_embedding_loss",
"torch.nn.functional.cross_entropy",
"torch.nn.functional.ctc_loss",
"torch.nn.functional.log_softmax",
"torch.nn.functional.hinge_embedding_loss",
"torch.nn.functional.kl_div",
"torch.nn.functional.l1_loss",
"torch.nn.functional.mse_loss",
"torch.nn.functional.margin_ranking_loss",
"torch.nn.functional.multilabel_margin_loss",
"torch.nn.functional.multilabel_soft_margin_loss",
"torch.nn.functional.multi_margin_loss",
"torch.nn.functional.nll_loss",
"torch.nn.functional.smooth_l1_loss",
"torch.nn.functional.soft_margin_loss",
"torch.nn.functional.triplet_margin_loss",
"torch.nn.functional.pixel_shuffle",
"torch.nn.functional.pixel_shuffle",
"torch.nn.functional.pad",
"torch.nn.functional.interpolate",
"torch.nn.functional.upsample",
"torch.nn.functional.interpolate",
"torch.nn.functional.upsample_nearest",
"torch.nn.functional.interpolate",
"torch.nn.functional.upsample_bilinear",
"torch.nn.functional.interpolate",
"torch.nn.functional.grid_sample",
"torch.nn.functional.affine_grid"
]
\ No newline at end of file
[
"nn.Module",
"nn.CELU",
"nn.ELU",
"nn.GLU",
"nn.GELU",
"nn.Hardshrink",
"nn.Hardtanh",
"nn.LeakyReLU",
"nn.LogSigmoid",
"nn.LogSoftmax",
"nn.PReLU",
"nn.RReLU",
"nn.ReLU",
"nn.ReLU6",
"nn.SELU",
"nn.Sigmoid",
"nn.Softmax",
"nn.Softmax2d",
"nn.Softmin",
"nn.Softplus",
"nn.Softshrink",
"nn.Softsign",
"nn.Tanh",
"nn.Tanhshrink",
"nn.Threshold",
"nn.MultiheadAttention",
"nn.AdaptiveLogSoftmaxWithLoss",
"nn.BatchNorm1d",
"nn.BatchNorm2d",
"nn.BatchNorm3d",
"nn.SyncBatchNorm",
"nn.Container",
"nn.ModuleDict",
"nn.ModuleList",
"nn.ParameterDict",
"nn.ParameterList",
"nn.Sequential",
"nn.Conv1d",
"nn.Conv2d",
"nn.Conv3d",
"nn.ConvTranspose1d",
"nn.ConvTranspose2d",
"nn.ConvTranspose3d",
"nn.CosineSimilarity",
"nn.PairwiseDistance",
"nn.AlphaDropout",
"nn.Dropout",
"nn.Dropout2d",
"nn.Dropout3d",
"nn.FeatureAlphaDropout",
"nn.Fold",
"nn.Unfold",
"nn.InstanceNorm1d",
"nn.InstanceNorm2d",
"nn.InstanceNorm3d",
"nn.Bilinear",
"nn.Identity",
"nn.Linear",
"nn.BCELoss",
"nn.BCEWithLogitsLoss",
"nn.CTCLoss",
"nn.CosineEmbeddingLoss",
"nn.CrossEntropyLoss",
"nn.HingeEmbeddingLoss",
"nn.KLDivLoss",
"nn.L1Loss",
"nn.MSELoss",
"nn.MarginRankingLoss",
"nn.MultiLabelMarginLoss",
"nn.MultiLabelSoftMarginLoss",
"nn.MultiMarginLoss",
"nn.NLLLoss",
"nn.NLLLoss2d",
"nn.PoissonNLLLoss",
"nn.SmoothL1Loss",
"nn.SoftMarginLoss",
"nn.TripletMarginLoss",
"nn.Module",
"nn.CrossMapLRN2d",
"nn.GroupNorm",
"nn.LayerNorm",
"nn.LocalResponseNorm",
"nn.ConstantPad1d",
"nn.ConstantPad2d",
"nn.ConstantPad3d",
"nn.ReflectionPad1d",
"nn.ReflectionPad2d",
"nn.ReplicationPad1d",
"nn.ReplicationPad2d",
"nn.ReplicationPad3d",
"nn.ZeroPad2d",
"nn.PixelShuffle",
"nn.AdaptiveAvgPool1d",
"nn.AdaptiveAvgPool2d",
"nn.AdaptiveAvgPool3d",
"nn.AdaptiveMaxPool1d",
"nn.AdaptiveMaxPool2d",
"nn.AdaptiveMaxPool3d",
"nn.AvgPool1d",
"nn.AvgPool2d",
"nn.AvgPool3d",
"nn.FractionalMaxPool2d",
"nn.FractionalMaxPool3d",
"nn.LPPool1d",
"nn.LPPool2d",
"nn.MaxPool1d",
"nn.MaxPool2d",
"nn.MaxPool3d",
"nn.MaxUnpool1d",
"nn.MaxUnpool2d",
"nn.MaxUnpool3d",
"nn.GRU",
"nn.GRUCell",
"nn.LSTM",
"nn.LSTMCell",
"nn.RNN",
"nn.RNNBase",
"nn.RNNCell",
"nn.RNNCellBase",
"nn.Embedding",
"nn.EmbeddingBag",
"nn.Upsample",
"nn.UpsamplingBilinear2d",
"nn.UpsamplingNearest2d",
"nn.Transformer",
"nn.TransformerEncoder",
"nn.TransformerDecoder",
"nn.TransformerEncoderLayer",
"nn.TransformerDecoderLayer",
"nn.Parameter",
"nn.DataParallel"
]
\ No newline at end of file
[
".new_tensor",
".new_full",
".new_empty",
".new_ones",
".new_zeros",
".abs",
".abs_",
".acos",
".acos_",
".add",
".add_",
".addbmm",
".addbmm_",
".addcdiv",
".addcdiv_",
".addcmul",
".addcmul_",
".addmm",
".addmm_",
".addmv",
".addmv_",
".addr",
".addr_",
".allclose",
".angle",
".apply_",
".argmax",
".argmin",
".argsort",
".asin",
".asin_",
".as_strided",
".atan",
".atan2",
".atan2_",
".atan_",
".baddbmm",
".baddbmm_",
".bernoulli",
".bernoulli_",
".bfloat16",
".bincount",
".bitwise_not",
".bitwise_not_",
".bitwise_and",
".bitwise_and_",
".bitwise_or",
".bitwise_or_",
".bitwise_xor",
".bitwise_xor_",
".bmm",
".bool",
".byte",
".cauchy_",
".ceil",
".ceil_",
".char",
".cholesky",
".cholesky_inverse",
".cholesky_solve",
".chunk",
".clamp",
".clamp_",
".clone",
".contiguous",
".copy_",
".conj",
".cos",
".cos_",
".cosh",
".cosh_",
".cpu",
".cross",
".cuda",
".cummax",
".cummin",
".cumprod",
".cumsum",
".data_ptr",
".dequantize",
".det",
".dense_dim",
".diag",
".diag_embed",
".diagflat",
".diagonal",
".fill_diagonal_",
".digamma",
".digamma_",
".dim",
".dist",
".div",
".div_",
".dot",
".double",
".eig",
".element_size",
".eq",
".eq_",
".equal",
".erf",
".erf_",
".erfc",
".erfc_",
".erfinv",
".erfinv_",
".exp",
".exp_",
".expm1",
".expm1_",
".expand",
".expand_as",
".exponential_",
".fft",
".fill_",
".flatten",
".flip",
".float",
".floor",
".floor_",
".floor_divide",
".floor_divide_",
".fmod",
".fmod_",
".frac",
".frac_",
".gather",
".ge",
".ge_",
".geometric_",
".geqrf",
".ger",
".get_device",
".gt",
".gt_",
".half",
".hardshrink",
".histc",
".ifft",
".index_add_",
".index_add",
".index_copy_",
".index_copy",
".index_fill_",
".index_fill",
".index_put_",
".index_put",
".index_select",
".indices",
".int",
".int_repr",
".inverse",
".irfft",
".is_contiguous",
".is_complex",
".is_floating_point",
".is_pinned",
".is_set_to",
".is_shared",
".is_signed",
".item",
".kthvalue",
".le",
".le_",
".lerp",
".lerp_",
".lgamma",
".lgamma_",
".log",
".log_",
".logdet",
".log10",
".log10_",
".log1p",
".log1p_",
".log2",
".log2_",
".log_normal_",
".logsumexp",
".logical_and",
".logical_and_",
".logical_not",
".logical_not_",
".logical_or",
".logical_or_",
".logical_xor",
".logical_xor_",
".long",
".lstsq",
".lt",
".lt_",
".lu",
".lu_solve",
".map_",
".masked_scatter_",
".masked_scatter",
".masked_fill_",
".masked_fill",
".masked_select",
".matmul",
".matrix_power",
".max",
".mean",
".median",
".min",
".mm",
".mode",
".mul",
".mul_",
".multinomial",
".mv",
".mvlgamma",
".mvlgamma_",
".narrow",
".narrow_copy",
".ndimension",
".ne",
".ne_",
".neg",
".neg_",
".nelement",
".nonzero",
".norm",
".normal_",
".numel",
".numpy",
".orgqr",
".ormqr",
".permute",
".pin_memory",
".pinverse",
".polygamma",
".polygamma_",
".pow",
".pow_",
".prod",
".put_",
".qr",
".qscheme",
".q_scale",
".q_zero_point",
".q_per_channel_scales",
".q_per_channel_zero_points",
".q_per_channel_axis",
".random_",
".reciprocal",
".reciprocal_",
".record_stream",
".remainder",
".remainder_",
".renorm",
".renorm_",
".repeat",
".repeat_interleave",
".requires_grad_",
".reshape",
".reshape_as",
".resize_",
".resize_as_",
".rfft",
".roll",
".rot90",
".round",
".round_",
".rsqrt",
".rsqrt_",
".scatter",
".scatter_",
".scatter_add_",
".scatter_add",
".select",
".set_",
".share_memory_",
".short",
".sigmoid",
".sigmoid_",
".sign",
".sign_",
".sin",
".sin_",
".sinh",
".sinh_",
".size",
".slogdet",
".solve",
".sort",
".split",
".sparse_mask",
".sparse_dim",
".sqrt",
".sqrt_",
".square",
".square_",
".squeeze",
".squeeze_",
".std",
".stft",
".storage",
".storage_offset",
".storage_type",
".stride",
".sub",
".sub_",
".sum",
".sum_to_size",
".svd",
".symeig",
".t",
".t_",
".to",
".to_mkldnn",
".take",
".tan",
".tan_",
".tanh",
".tanh_",
".tolist",
".topk",
".to_sparse",
".trace",
".transpose",
".transpose_",
".triangular_solve",
".tril",
".tril_",
".triu",
".triu_",
".true_divide",
".true_divide_",
".trunc",
".trunc_",
".type",
".type_as",
".unbind",
".unfold",
".uniform_",
".unique",
".unique_consecutive",
".unsqueeze",
".unsqueeze_",
".values",
".var",
".view",
".view_as",
".where",
".zero_"
]
\ No newline at end of file
[
"torch.is_tensor",
"torch.is_storage",
"torch.is_complex",
"torch.is_floating_point",
"torch.set_default_dtype",
"torch.get_default_dtype",
"torch.set_default_tensor_type",
"torch.numel",
"torch.set_printoptions",
"torch.set_flush_denormal",
"torch.tensor",
"torch.sparse_coo_tensor",
"torch.as_tensor",
"torch.as_strided",
"torch.from_numpy",
"torch.zeros",
"torch.zeros_like",
"torch.ones",
"torch.ones_like",
"torch.arange",
"torch.range",
"torch.linspace",
"torch.logspace",
"torch.eye",
"torch.empty",
"torch.empty_like",
"torch.empty_strided",
"torch.full",
"torch.full_like",
"torch.quantize_per_tensor",
"torch.quantize_per_channel",
"torch.cat",
"torch.chunk",
"torch.gather",
"torch.index_select",
"torch.masked_select",
"torch.narrow",
"torch.nonzero",
"torch.reshape",
"torch.split",
"torch.squeeze",
"torch.stack",
"torch.t",
"torch.take",
"torch.transpose",
"torch.unbind",
"torch.unsqueeze",
"torch.where",
"torch._C.Generator",
"torch._C.Generator.device",
"torch._C.Generator.get_state",
"torch._C.Generator.initial_seed",
"torch._C.Generator.manual_seed",
"torch._C.Generator.seed",
"torch._C.Generator.set_state",
"torch.seed",
"torch.manual_seed",
"torch.initial_seed",
"torch.get_rng_state",
"torch.set_rng_state",
"torch.default_generator",
"torch.bernoulli",
"torch.multinomial",
"torch.normal",
"torch.poisson",
"torch.rand",
"torch.rand_like",
"torch.randint",
"torch.randint_like",
"torch.randn",
"torch.randn_like",
"torch.randperm",
"torch.quasirandom.SobolEngine",
"torch.quasirandom.SobolEngine.draw",
"torch.quasirandom.SobolEngine.fast_forward",
"torch.quasirandom.SobolEngine.reset",
"torch.save",
"torch.load",
"torch.get_num_threads",
"torch.set_num_threads",
"torch.get_num_interop_threads",
"torch.set_num_interop_threads",
"torch.no_grad",
"torch.enable_grad",
"torch.set_grad_enabled",
"torch.abs",
"torch.acos",
"torch.add",
"torch.addcdiv",
"torch.addcmul",
"torch.angle",
"torch.asin",
"torch.atan",
"torch.atan2",
"torch.bitwise_not",
"torch.bitwise_and",
"torch.bitwise_or",
"torch.bitwise_xor",
"torch.ceil",
"torch.clamp",
"torch.conj",
"torch.cos",
"torch.cosh",
"torch.div",
"torch.digamma",
"torch.erf",
"torch.erfc",
"torch.erfinv",
"torch.exp",
"torch.expm1",
"torch.floor",
"torch.floor_divide",
"torch.fmod",
"torch.frac",
"torch.imag",
"torch.lerp",
"torch.lgamma",
"torch.log",
"torch.log10",
"torch.log1p",
"torch.log2",
"torch.logical_and",
"torch.logical_not",
"torch.logical_or",
"torch.logical_xor",
"torch.mul",
"torch.mvlgamma",
"torch.neg",
"torch.polygamma",
"torch.pow",
"torch.real",
"torch.reciprocal",
"torch.remainder",
"torch.round",
"torch.rsqrt",
"torch.sigmoid",
"torch.sign",
"torch.sin",
"torch.sinh",
"torch.sqrt",
"torch.square",
"torch.tan",
"torch.tanh",
"torch.true_divide",
"torch.trunc",
"torch.argmax",
"torch.argmin",
"torch.dist",
"torch.logsumexp",
"torch.mean",
"torch.median",
"torch.mode",
"torch.norm",
"torch.prod",
"torch.std",
"torch.std_mean",
"torch.sum",
"torch.unique",
"torch.unique_consecutive",
"torch.var",
"torch.var_mean",
"torch.allclose",
"torch.argsort",
"torch.eq",
"torch.equal",
"torch.ge",
"torch.gt",
"torch.isfinite",
"torch.isinf",
"torch.isnan",
"torch.kthvalue",
"torch.le",
"torch.lt",
"torch.max",
"torch.min",
"torch.ne",
"torch.sort",
"torch.topk",
"torch.fft",
"torch.ifft",
"torch.rfft",
"torch.irfft",
"torch.stft",
"torch.bartlett_window",
"torch.blackman_window",
"torch.hamming_window",
"torch.hann_window",
"torch.bincount",
"torch.broadcast_tensors",
"torch.cartesian_prod",
"torch.cdist",
"torch.combinations",
"torch.cross",
"torch.cummax",
"torch.cummin",
"torch.cumprod",
"torch.cumsum",
"torch.diag",
"torch.diag_embed",
"torch.diagflat",
"torch.diagonal",
"torch.einsum",
"torch.flatten",
"torch.flip",
"torch.rot90",
"torch.histc",
"torch.meshgrid",
"torch.renorm",
"torch.repeat_interleave",
"torch.roll",
"torch.tensordot",
"torch.trace",
"torch.tril",
"torch.tril_indices",
"torch.triu",
"torch.triu_indices",
"torch.addbmm",
"torch.addmm",
"torch.addmv",
"torch.addr",
"torch.baddbmm",
"torch.bmm",
"torch.chain_matmul",
"torch.cholesky",
"torch.cholesky_inverse",
"torch.cholesky_solve",
"torch.dot",
"torch.eig",
"torch.geqrf",
"torch.ger",
"torch.inverse",
"torch.det",
"torch.logdet",
"torch.slogdet",
"torch.lstsq",
"torch.lu",
"torch.lu_solve",
"torch.lu_unpack",
"torch.matmul",
"torch.matrix_power",
"torch.matrix_rank",
"torch.mm",
"torch.mv",
"torch.orgqr",
"torch.ormqr",
"torch.pinverse",
"torch.qr",
"torch.solve",
"torch.svd",
"torch.svd_lowrank",
"torch.pca_lowrank",
"torch.symeig",
"torch.lobpcg",
"torch.trapz",
"torch.triangular_solve",
"torch.compiled_with_cxx11_abi",
"torch.result_type",
"torch.can_cast",
"torch.promote_types"
]
\ No newline at end of file
......@@ -2,6 +2,7 @@ attrdict>=2.0.1
Click>=7.0
Flask>=1.1.1
Flask-Cors>=3.0.8
google-pasta>=0.1.8
gunicorn>=19.9.0
itsdangerous>=1.1.0
Jinja2>=2.10.1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册