提交 16e94c3f 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!135 torch2ms convert

Merge pull request !135 from quyongxiu1/br_qyx_0520
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Create a logger."""
from mindinsight.utils.log import setup_logger
logger = setup_logger("mindconverter", "mindconverter")
此差异已折叠。
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless REQUIRED by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""main module"""
import inspect
import copy
import importlib
import os
import stat
from mindinsight.mindconverter.config import ALL_MAPPING
from mindinsight.mindconverter.config import NN_LIST
from mindinsight.mindconverter.config import ALL_TORCH_APIS
from mindinsight.mindconverter.config import ALL_2P_LIST
from mindinsight.mindconverter.config import UNSUPPORTED_WARN_INFOS
from mindinsight.mindconverter.common.log import logger
def is_local_defined(obj, member):
"""
Check if obj and member are both defined in the same source file.
Args:
obj (Union[object, module]): A module or a class.
member (func): A function of obj.
Returns:
bool, True or False.
"""
srcfile = inspect.getsourcefile(obj)
return inspect.getsourcefile(member) == srcfile
def is_valid_module(obj, member):
"""
Check if obj and member defined in same source file and member is inherited from torch.nn.Module.
Args:
obj (Union[object, module]): A module or a class.
member (func): A function.
Returns:
bool, True or False.
"""
return inspect.isclass(member) and (member.__base__.__name__ == 'Module') and is_local_defined(obj, member)
def is_valid_function(obj, member):
"""
Check if member is function and defined in the file same as obj.
Args:
obj (Union[object, module]: The obj.
member (func): The func.
Returns:
bool, True or False.
"""
return inspect.isfunction(member) and is_local_defined(obj, member)
def find_left_parentheses(string, right):
"""
Find index of the first left parenthesis.
Args:
string (str): A line of code.
right (int): Max index of string, same as `len(string) -1`.
Returns:
int, index of the first parenthesis.
Raises:
ValueError: If line of code doesn't contain any pair of `()` or `(` and `)` are not paired.
"""
if string[right] != ')':
raise ValueError('code [{}] at index {} not ")".'.format(string, right))
stack = []
for i in range(right, -1, -1):
if string[i] == ')':
stack.append(')')
elif string[i] == '(':
stack.pop()
if not stack:
return i
raise ValueError("{} should contain ()".format(string))
def find_right_parentheses(string, left):
"""
Find first index of right parenthesis which make all left parenthesis make sense.
Args:
string (str): A line of code.
left (int): Start index of string to find from.
Returns:
int, index of the found right parenthesis.
Raises:
ValueError: If line of code doesn't contain any pair of `()` or `(` and `)` are not paired.
"""
stack = []
for i in range(left, len(string)):
if string[i] == '(':
stack.append('(')
elif string[i] == ')':
stack.pop()
if not stack:
return i
raise ValueError("{} should contain ()".format(string))
def get_call_name(code, end):
"""
Traverse code in a reversed function from index end and get the call name and start index of the call name, if call
name not found, return a null character string and -1
Args:
code (str): The str of code to find from.
end (int): Start index to find.
Returns:
str, founded api name if found, else a null character string.
int, start index of founded api name, -1 if api name not found
"""
stack = []
for i in range(end - 1, -1, -1):
if code[i] in ["(", "[", "{"]:
if stack:
stack.pop()
else:
return code[i + 1:end], i + 1
elif code[i] in [")", "]", "}"]:
stack.append(code[i])
elif stack:
continue
elif not (code[i].isalpha() or code[i].isdigit() or code[i] == '_' or code[i] == '.'):
return code[i + 1:end], i + 1
return "", -1
def convert_api(code, start, api_name=""):
"""
Convert api_name in code to MindSpore api with start as a start index, if api_name is a python api, code will not
convert.
Args:
code (str): The str code to convert.
start (int): The index of code to start convert from.
api_name (str): The api name to convert.
Returns:
str, the converted code.
int, index of converted api_name in code.
"""
# handle format like .shape(
if api_name.startswith('.'):
call_name, new_start = get_call_name(code, start)
if start == -1 or call_name == "self":
return code, start + 1
else:
call_name = api_name
new_start = start
# find full api expected to be converted. eg:expr="nn.Conv2d(1,2,3)" args_str="(1,2,3)"
left = code.find("(", start)
if left == -1:
raise ValueError('"(" not found, {} should work with "("'.format(call_name))
right = find_right_parentheses(code, left)
end = right
expr = code[start:end + 1]
args_str = code[left:right + 1]
map_helper = ALL_MAPPING[api_name]
new_expr = map_helper.convert(call_name, args_str)
next_newline = code.find("\n", end + 1)
fill_num = (expr.count("\n") - new_expr.count("\n"))
if next_newline != -1:
code = code[:new_start] + new_expr + code[end + 1:next_newline] + ("\n" * fill_num) + code[next_newline:]
else:
code = code[:new_start] + new_expr + ")" + ("\n" * fill_num) + code[end + 2:]
return code, start + len(map_helper.ms_api.name)
def find_api(code, i, is_forward):
"""
Find api name from code with a start index i, check api name ok with a is_forward condition.
Args:
code (str): The code from which to find api name.
i (int): The start index to find.
is_forward (bool): Check if the found api name ok.
Returns:
str, api name if find api name and check ok with is_forward condition, else a null character string.
"""
if code[i:].startswith("nn.") \
or code[i:].startswith("F.") \
or code[i:].startswith("torch.") \
or code[i:].startswith('.'):
j = code.find('(', i)
if j != -1 and code[i:j] in ALL_TORCH_APIS:
api_name = code[i:j]
if (not is_forward and api_name in NN_LIST) or (is_forward and api_name in ALL_2P_LIST):
return api_name
return ""
def convert_function(fun_name, fun, is_forward):
"""
Convert a PyTorch function into MindSpore function.
Args:
fun_name (str): The str of function name.
fun (func): The function to convert.
is_forward (bool): If the function is defined in forward function in nn.Module in torch.
Returns:
dict, old code and converted code map if convert happens, else {}.
"""
_, line_no = inspect.getsourcelines(fun)
logger.info("Line %3d: start converting function %s()", line_no, fun_name)
code = inspect.getsource(fun)
code_saved = copy.copy(code)
i = 0
while i < len(code):
api_name = find_api(code, i, is_forward)
if api_name:
line_no1 = line_no + code[:i].count('\n')
if api_name in ALL_MAPPING:
logger.info("Line %3d start converting API: %s", line_no1, api_name)
code, i = convert_api(code, i, api_name)
continue
warn_info = ". " + UNSUPPORTED_WARN_INFOS[api_name] if api_name in UNSUPPORTED_WARN_INFOS else ""
logger.warning("Line %3d: found unsupported API: %s%s", line_no1, api_name, warn_info)
i += 1
return {code_saved: code} if code_saved != code else {}
def judge_forward(name, forward_list):
"""
Check if function is a forward function.
Args:
name (str): The function name.
forward_list (set): A set of forward function.
Returns:
bool, True or False
"""
is_forward = name in forward_list or name.split(".")[-1] == "forward"
if is_forward:
logger.debug("%s is a forward function", name)
return is_forward
def convert_module(module_name, module, forward_list):
"""
Convert a PyTorch module code into MindSpore module code.
Args:
module_name (str): The module's name.
module (module): The module to convert.
forward_list (set): A set of forward function.
Returns:
dict, map of old code and converted code.
"""
_, line_no = inspect.getsourcelines(module)
logger.info("Line {:3d}: start converting nn.Module {}".format(line_no, module_name))
mapped = {}
for name, member in inspect.getmembers(module):
if is_valid_function(module, member):
is_forward = judge_forward("{}.{}".format(module_name, name), forward_list)
mapped.update(convert_function(name, member, is_forward))
return mapped
def get_mapping(import_mod, forward_list):
"""
Convert code of a module and get mapping of old code and convert code.
Args:
import_mod (module): The module to convert.
forward_list (set): A set of forward function.
Returns:
dict, mapping for old code and converted code of the module
"""
mapping = {}
tasks = []
for name, member in inspect.getmembers(import_mod):
if is_valid_module(import_mod, member):
_, line_no = inspect.getsourcelines(member)
tasks.append((line_no, convert_module, (name, member, forward_list)))
elif is_valid_function(import_mod, member):
_, line_no = inspect.getsourcelines(member)
is_forward = judge_forward("{}.{}".format(import_mod, name), forward_list)
tasks.append((line_no, convert_function, (name, member, is_forward)))
tasks.sort()
for _, convert_fun, args in tasks:
mapping.update(convert_fun(*args))
return mapping
def convert(import_name, nn_module):
"""
The entrance for convert a module's code, code converted will be write to file called out.py.
Args:
import_name (str): The module from which to import the module to convert.
nn_module (str): Name of the module to convert.
"""
logger.info("Start converting %s.%s", import_name, nn_module)
import_mod = importlib.import_module(import_name)
forward_list = set()
logger.debug("Forward_list: %s", forward_list)
# replace python function under nn.Modlue
mapping = get_mapping(import_mod, forward_list)
code = inspect.getsource(import_mod)
for key, value in mapping.items():
code = code.replace(key, value)
code = 'import mindspore.ops.operations as P\n' + code
code = 'import mindspore.nn as nn\n' + code
code = 'import mindspore\n' + code
code = code.replace('import torch', '# import torch')
code = code.replace('from torch', '# from torch')
code = code.replace('(nn.Module):', '(nn.Cell):')
code = code.replace('forward(', 'construct(')
code = code.replace('nn.Linear', 'nn.Dense')
code = code.replace('(nn.Sequential)', '(nn.SequentialCell)')
code = code.replace('nn.init.', 'pass # nn.init.')
flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL
modes = stat.S_IWUSR | stat.S_IRUSR
with os.fdopen(os.open('out.py', flags, modes), 'w') as file:
file.write(code)
logger.info("Convert success. Result is wrote to out.py\n")
if __name__ == '__main__':
convert('torchvision.models.resnet', 'resnet18')
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Enums."""
from enum import Enum
class RequriedType(Enum):
"""If param is required"""
REQUIRED = 1
UNREQUIRED = 2
[
"torch.nn.functional.conv1d",
"torch.nn.functional.conv2d",
"torch.nn.functional.conv3d",
"torch.nn.functional.conv_transpose1d",
"torch.nn.functional.conv_transpose2d",
"torch.nn.functional.conv_transpose3d",
"torch.nn.functional.unfold",
"torch.nn.functional.fold",
"torch.nn.functional.avg_pool1d",
"torch.nn.functional.avg_pool2d",
"torch.nn.functional.avg_pool3d",
"torch.nn.functional.max_pool1d",
"torch.nn.functional.max_pool2d",
"torch.nn.functional.max_pool3d",
"torch.nn.functional.max_unpool1d",
"torch.nn.functional.max_unpool2d",
"torch.nn.functional.max_unpool3d",
"torch.nn.functional.lp_pool1d",
"torch.nn.functional.lp_pool2d",
"torch.nn.functional.adaptive_max_pool1d",
"torch.nn.functional.adaptive_max_pool2d",
"torch.nn.functional.adaptive_max_pool3d",
"torch.nn.functional.adaptive_avg_pool1d",
"torch.nn.functional.adaptive_avg_pool2d",
"torch.nn.functional.adaptive_avg_pool3d",
"torch.nn.functional.threshold",
"torch.nn.functional.threshold_",
"torch.nn.functional.relu",
"torch.nn.functional.relu_",
"torch.nn.functional.hardtanh",
"torch.nn.functional.hardtanh_",
"torch.nn.functional.relu6",
"torch.nn.functional.elu",
"torch.nn.functional.elu_",
"torch.nn.functional.selu",
"torch.nn.functional.celu",
"torch.nn.functional.leaky_relu",
"torch.nn.functional.leaky_relu_",
"torch.nn.functional.prelu",
"torch.nn.functional.rrelu",
"torch.nn.functional.rrelu_",
"torch.nn.functional.glu",
"torch.nn.functional.gelu",
"torch.nn.functional.logsigmoid",
"torch.nn.functional.hardshrink",
"torch.nn.functional.tanhshrink",
"torch.nn.functional.softsign",
"torch.nn.functional.softplus",
"torch.nn.functional.softmin",
"torch.nn.functional.softmax",
"torch.nn.functional.softshrink",
"torch.nn.functional.gumbel_softmax",
"torch.nn.functional.log_softmax",
"torch.nn.functional.tanh",
"torch.nn.functional.sigmoid",
"torch.nn.functional.batch_norm",
"torch.nn.functional.instance_norm",
"torch.nn.functional.layer_norm",
"torch.nn.functional.local_response_norm",
"torch.nn.functional.normalize",
"torch.nn.functional.linear",
"torch.nn.functional.bilinear",
"torch.nn.functional.dropout",
"torch.nn.functional.alpha_dropout",
"torch.nn.functional.dropout2d",
"torch.nn.functional.dropout3d",
"torch.nn.functional.embedding",
"torch.nn.functional.embedding_bag",
"torch.nn.functional.one_hot",
"torch.nn.functional.pairwise_distance",
"torch.nn.functional.cosine_similarity",
"torch.nn.functional.pdist",
"torch.nn.functional.binary_cross_entropy",
"torch.nn.functional.binary_cross_entropy_with_logits",
"torch.nn.functional.poisson_nll_loss",
"torch.nn.functional.cosine_embedding_loss",
"torch.nn.functional.cross_entropy",
"torch.nn.functional.ctc_loss",
"torch.nn.functional.log_softmax",
"torch.nn.functional.hinge_embedding_loss",
"torch.nn.functional.kl_div",
"torch.nn.functional.l1_loss",
"torch.nn.functional.mse_loss",
"torch.nn.functional.margin_ranking_loss",
"torch.nn.functional.multilabel_margin_loss",
"torch.nn.functional.multilabel_soft_margin_loss",
"torch.nn.functional.multi_margin_loss",
"torch.nn.functional.nll_loss",
"torch.nn.functional.smooth_l1_loss",
"torch.nn.functional.soft_margin_loss",
"torch.nn.functional.triplet_margin_loss",
"torch.nn.functional.pixel_shuffle",
"torch.nn.functional.pixel_shuffle",
"torch.nn.functional.pad",
"torch.nn.functional.interpolate",
"torch.nn.functional.upsample",
"torch.nn.functional.interpolate",
"torch.nn.functional.upsample_nearest",
"torch.nn.functional.interpolate",
"torch.nn.functional.upsample_bilinear",
"torch.nn.functional.interpolate",
"torch.nn.functional.grid_sample",
"torch.nn.functional.affine_grid"
]
\ No newline at end of file
[
"nn.Module",
"nn.CELU",
"nn.ELU",
"nn.GLU",
"nn.GELU",
"nn.Hardshrink",
"nn.Hardtanh",
"nn.LeakyReLU",
"nn.LogSigmoid",
"nn.LogSoftmax",
"nn.PReLU",
"nn.RReLU",
"nn.ReLU",
"nn.ReLU6",
"nn.SELU",
"nn.Sigmoid",
"nn.Softmax",
"nn.Softmax2d",
"nn.Softmin",
"nn.Softplus",
"nn.Softshrink",
"nn.Softsign",
"nn.Tanh",
"nn.Tanhshrink",
"nn.Threshold",
"nn.MultiheadAttention",
"nn.AdaptiveLogSoftmaxWithLoss",
"nn.BatchNorm1d",
"nn.BatchNorm2d",
"nn.BatchNorm3d",
"nn.SyncBatchNorm",
"nn.Container",
"nn.ModuleDict",
"nn.ModuleList",
"nn.ParameterDict",
"nn.ParameterList",
"nn.Sequential",
"nn.Conv1d",
"nn.Conv2d",
"nn.Conv3d",
"nn.ConvTranspose1d",
"nn.ConvTranspose2d",
"nn.ConvTranspose3d",
"nn.CosineSimilarity",
"nn.PairwiseDistance",
"nn.AlphaDropout",
"nn.Dropout",
"nn.Dropout2d",
"nn.Dropout3d",
"nn.FeatureAlphaDropout",
"nn.Fold",
"nn.Unfold",
"nn.InstanceNorm1d",
"nn.InstanceNorm2d",
"nn.InstanceNorm3d",
"nn.Bilinear",
"nn.Identity",
"nn.Linear",
"nn.BCELoss",
"nn.BCEWithLogitsLoss",
"nn.CTCLoss",
"nn.CosineEmbeddingLoss",
"nn.CrossEntropyLoss",
"nn.HingeEmbeddingLoss",
"nn.KLDivLoss",
"nn.L1Loss",
"nn.MSELoss",
"nn.MarginRankingLoss",
"nn.MultiLabelMarginLoss",
"nn.MultiLabelSoftMarginLoss",
"nn.MultiMarginLoss",
"nn.NLLLoss",
"nn.NLLLoss2d",
"nn.PoissonNLLLoss",
"nn.SmoothL1Loss",
"nn.SoftMarginLoss",
"nn.TripletMarginLoss",
"nn.Module",
"nn.CrossMapLRN2d",
"nn.GroupNorm",
"nn.LayerNorm",
"nn.LocalResponseNorm",
"nn.ConstantPad1d",
"nn.ConstantPad2d",
"nn.ConstantPad3d",
"nn.ReflectionPad1d",
"nn.ReflectionPad2d",
"nn.ReplicationPad1d",
"nn.ReplicationPad2d",
"nn.ReplicationPad3d",
"nn.ZeroPad2d",
"nn.PixelShuffle",
"nn.AdaptiveAvgPool1d",
"nn.AdaptiveAvgPool2d",
"nn.AdaptiveAvgPool3d",
"nn.AdaptiveMaxPool1d",
"nn.AdaptiveMaxPool2d",
"nn.AdaptiveMaxPool3d",
"nn.AvgPool1d",
"nn.AvgPool2d",
"nn.AvgPool3d",
"nn.FractionalMaxPool2d",
"nn.FractionalMaxPool3d",
"nn.LPPool1d",
"nn.LPPool2d",
"nn.MaxPool1d",
"nn.MaxPool2d",
"nn.MaxPool3d",
"nn.MaxUnpool1d",
"nn.MaxUnpool2d",
"nn.MaxUnpool3d",
"nn.GRU",
"nn.GRUCell",
"nn.LSTM",
"nn.LSTMCell",
"nn.RNN",
"nn.RNNBase",
"nn.RNNCell",
"nn.RNNCellBase",
"nn.Embedding",
"nn.EmbeddingBag",
"nn.Upsample",
"nn.UpsamplingBilinear2d",
"nn.UpsamplingNearest2d",
"nn.Transformer",
"nn.TransformerEncoder",
"nn.TransformerDecoder",
"nn.TransformerEncoderLayer",
"nn.TransformerDecoderLayer",
"nn.Parameter",
"nn.DataParallel"
]
\ No newline at end of file
[
".new_tensor",
".new_full",
".new_empty",
".new_ones",
".new_zeros",
".abs",
".abs_",
".acos",
".acos_",
".add",
".add_",
".addbmm",
".addbmm_",
".addcdiv",
".addcdiv_",
".addcmul",
".addcmul_",
".addmm",
".addmm_",
".addmv",
".addmv_",
".addr",
".addr_",
".allclose",
".angle",
".apply_",
".argmax",
".argmin",
".argsort",
".asin",
".asin_",
".as_strided",
".atan",
".atan2",
".atan2_",
".atan_",
".baddbmm",
".baddbmm_",
".bernoulli",
".bernoulli_",
".bfloat16",
".bincount",
".bitwise_not",
".bitwise_not_",
".bitwise_and",
".bitwise_and_",
".bitwise_or",
".bitwise_or_",
".bitwise_xor",
".bitwise_xor_",
".bmm",
".bool",
".byte",
".cauchy_",
".ceil",
".ceil_",
".char",
".cholesky",
".cholesky_inverse",
".cholesky_solve",
".chunk",
".clamp",
".clamp_",
".clone",
".contiguous",
".copy_",
".conj",
".cos",
".cos_",
".cosh",
".cosh_",
".cpu",
".cross",
".cuda",
".cummax",
".cummin",
".cumprod",
".cumsum",
".data_ptr",
".dequantize",
".det",
".dense_dim",
".diag",
".diag_embed",
".diagflat",
".diagonal",
".fill_diagonal_",
".digamma",
".digamma_",
".dim",
".dist",
".div",
".div_",
".dot",
".double",
".eig",
".element_size",
".eq",
".eq_",
".equal",
".erf",
".erf_",
".erfc",
".erfc_",
".erfinv",
".erfinv_",
".exp",
".exp_",
".expm1",
".expm1_",
".expand",
".expand_as",
".exponential_",
".fft",
".fill_",
".flatten",
".flip",
".float",
".floor",
".floor_",
".floor_divide",
".floor_divide_",
".fmod",
".fmod_",
".frac",
".frac_",
".gather",
".ge",
".ge_",
".geometric_",
".geqrf",
".ger",
".get_device",
".gt",
".gt_",
".half",
".hardshrink",
".histc",
".ifft",
".index_add_",
".index_add",
".index_copy_",
".index_copy",
".index_fill_",
".index_fill",
".index_put_",
".index_put",
".index_select",
".indices",
".int",
".int_repr",
".inverse",
".irfft",
".is_contiguous",
".is_complex",
".is_floating_point",
".is_pinned",
".is_set_to",
".is_shared",
".is_signed",
".item",
".kthvalue",
".le",
".le_",
".lerp",
".lerp_",
".lgamma",
".lgamma_",
".log",
".log_",
".logdet",
".log10",
".log10_",
".log1p",
".log1p_",
".log2",
".log2_",
".log_normal_",
".logsumexp",
".logical_and",
".logical_and_",
".logical_not",
".logical_not_",
".logical_or",
".logical_or_",
".logical_xor",
".logical_xor_",
".long",
".lstsq",
".lt",
".lt_",
".lu",
".lu_solve",
".map_",
".masked_scatter_",
".masked_scatter",
".masked_fill_",
".masked_fill",
".masked_select",
".matmul",
".matrix_power",
".max",
".mean",
".median",
".min",
".mm",
".mode",
".mul",
".mul_",
".multinomial",
".mv",
".mvlgamma",
".mvlgamma_",
".narrow",
".narrow_copy",
".ndimension",
".ne",
".ne_",
".neg",
".neg_",
".nelement",
".nonzero",
".norm",
".normal_",
".numel",
".numpy",
".orgqr",
".ormqr",
".permute",
".pin_memory",
".pinverse",
".polygamma",
".polygamma_",
".pow",
".pow_",
".prod",
".put_",
".qr",
".qscheme",
".q_scale",
".q_zero_point",
".q_per_channel_scales",
".q_per_channel_zero_points",
".q_per_channel_axis",
".random_",
".reciprocal",
".reciprocal_",
".record_stream",
".remainder",
".remainder_",
".renorm",
".renorm_",
".repeat",
".repeat_interleave",
".requires_grad_",
".reshape",
".reshape_as",
".resize_",
".resize_as_",
".rfft",
".roll",
".rot90",
".round",
".round_",
".rsqrt",
".rsqrt_",
".scatter",
".scatter_",
".scatter_add_",
".scatter_add",
".select",
".set_",
".share_memory_",
".short",
".sigmoid",
".sigmoid_",
".sign",
".sign_",
".sin",
".sin_",
".sinh",
".sinh_",
".size",
".slogdet",
".solve",
".sort",
".split",
".sparse_mask",
".sparse_dim",
".sqrt",
".sqrt_",
".square",
".square_",
".squeeze",
".squeeze_",
".std",
".stft",
".storage",
".storage_offset",
".storage_type",
".stride",
".sub",
".sub_",
".sum",
".sum_to_size",
".svd",
".symeig",
".t",
".t_",
".to",
".to_mkldnn",
".take",
".tan",
".tan_",
".tanh",
".tanh_",
".tolist",
".topk",
".to_sparse",
".trace",
".transpose",
".transpose_",
".triangular_solve",
".tril",
".tril_",
".triu",
".triu_",
".true_divide",
".true_divide_",
".trunc",
".trunc_",
".type",
".type_as",
".unbind",
".unfold",
".uniform_",
".unique",
".unique_consecutive",
".unsqueeze",
".unsqueeze_",
".values",
".var",
".view",
".view_as",
".where",
".zero_"
]
\ No newline at end of file
[
"torch.is_tensor",
"torch.is_storage",
"torch.is_complex",
"torch.is_floating_point",
"torch.set_default_dtype",
"torch.get_default_dtype",
"torch.set_default_tensor_type",
"torch.numel",
"torch.set_printoptions",
"torch.set_flush_denormal",
"torch.tensor",
"torch.sparse_coo_tensor",
"torch.as_tensor",
"torch.as_strided",
"torch.from_numpy",
"torch.zeros",
"torch.zeros_like",
"torch.ones",
"torch.ones_like",
"torch.arange",
"torch.range",
"torch.linspace",
"torch.logspace",
"torch.eye",
"torch.empty",
"torch.empty_like",
"torch.empty_strided",
"torch.full",
"torch.full_like",
"torch.quantize_per_tensor",
"torch.quantize_per_channel",
"torch.cat",
"torch.chunk",
"torch.gather",
"torch.index_select",
"torch.masked_select",
"torch.narrow",
"torch.nonzero",
"torch.reshape",
"torch.split",
"torch.squeeze",
"torch.stack",
"torch.t",
"torch.take",
"torch.transpose",
"torch.unbind",
"torch.unsqueeze",
"torch.where",
"torch._C.Generator",
"torch._C.Generator.device",
"torch._C.Generator.get_state",
"torch._C.Generator.initial_seed",
"torch._C.Generator.manual_seed",
"torch._C.Generator.seed",
"torch._C.Generator.set_state",
"torch.seed",
"torch.manual_seed",
"torch.initial_seed",
"torch.get_rng_state",
"torch.set_rng_state",
"torch.default_generator",
"torch.bernoulli",
"torch.multinomial",
"torch.normal",
"torch.poisson",
"torch.rand",
"torch.rand_like",
"torch.randint",
"torch.randint_like",
"torch.randn",
"torch.randn_like",
"torch.randperm",
"torch.quasirandom.SobolEngine",
"torch.quasirandom.SobolEngine.draw",
"torch.quasirandom.SobolEngine.fast_forward",
"torch.quasirandom.SobolEngine.reset",
"torch.save",
"torch.load",
"torch.get_num_threads",
"torch.set_num_threads",
"torch.get_num_interop_threads",
"torch.set_num_interop_threads",
"torch.no_grad",
"torch.enable_grad",
"torch.set_grad_enabled",
"torch.abs",
"torch.acos",
"torch.add",
"torch.addcdiv",
"torch.addcmul",
"torch.angle",
"torch.asin",
"torch.atan",
"torch.atan2",
"torch.bitwise_not",
"torch.bitwise_and",
"torch.bitwise_or",
"torch.bitwise_xor",
"torch.ceil",
"torch.clamp",
"torch.conj",
"torch.cos",
"torch.cosh",
"torch.div",
"torch.digamma",
"torch.erf",
"torch.erfc",
"torch.erfinv",
"torch.exp",
"torch.expm1",
"torch.floor",
"torch.floor_divide",
"torch.fmod",
"torch.frac",
"torch.imag",
"torch.lerp",
"torch.lgamma",
"torch.log",
"torch.log10",
"torch.log1p",
"torch.log2",
"torch.logical_and",
"torch.logical_not",
"torch.logical_or",
"torch.logical_xor",
"torch.mul",
"torch.mvlgamma",
"torch.neg",
"torch.polygamma",
"torch.pow",
"torch.real",
"torch.reciprocal",
"torch.remainder",
"torch.round",
"torch.rsqrt",
"torch.sigmoid",
"torch.sign",
"torch.sin",
"torch.sinh",
"torch.sqrt",
"torch.square",
"torch.tan",
"torch.tanh",
"torch.true_divide",
"torch.trunc",
"torch.argmax",
"torch.argmin",
"torch.dist",
"torch.logsumexp",
"torch.mean",
"torch.median",
"torch.mode",
"torch.norm",
"torch.prod",
"torch.std",
"torch.std_mean",
"torch.sum",
"torch.unique",
"torch.unique_consecutive",
"torch.var",
"torch.var_mean",
"torch.allclose",
"torch.argsort",
"torch.eq",
"torch.equal",
"torch.ge",
"torch.gt",
"torch.isfinite",
"torch.isinf",
"torch.isnan",
"torch.kthvalue",
"torch.le",
"torch.lt",
"torch.max",
"torch.min",
"torch.ne",
"torch.sort",
"torch.topk",
"torch.fft",
"torch.ifft",
"torch.rfft",
"torch.irfft",
"torch.stft",
"torch.bartlett_window",
"torch.blackman_window",
"torch.hamming_window",
"torch.hann_window",
"torch.bincount",
"torch.broadcast_tensors",
"torch.cartesian_prod",
"torch.cdist",
"torch.combinations",
"torch.cross",
"torch.cummax",
"torch.cummin",
"torch.cumprod",
"torch.cumsum",
"torch.diag",
"torch.diag_embed",
"torch.diagflat",
"torch.diagonal",
"torch.einsum",
"torch.flatten",
"torch.flip",
"torch.rot90",
"torch.histc",
"torch.meshgrid",
"torch.renorm",
"torch.repeat_interleave",
"torch.roll",
"torch.tensordot",
"torch.trace",
"torch.tril",
"torch.tril_indices",
"torch.triu",
"torch.triu_indices",
"torch.addbmm",
"torch.addmm",
"torch.addmv",
"torch.addr",
"torch.baddbmm",
"torch.bmm",
"torch.chain_matmul",
"torch.cholesky",
"torch.cholesky_inverse",
"torch.cholesky_solve",
"torch.dot",
"torch.eig",
"torch.geqrf",
"torch.ger",
"torch.inverse",
"torch.det",
"torch.logdet",
"torch.slogdet",
"torch.lstsq",
"torch.lu",
"torch.lu_solve",
"torch.lu_unpack",
"torch.matmul",
"torch.matrix_power",
"torch.matrix_rank",
"torch.mm",
"torch.mv",
"torch.orgqr",
"torch.ormqr",
"torch.pinverse",
"torch.qr",
"torch.solve",
"torch.svd",
"torch.svd_lowrank",
"torch.pca_lowrank",
"torch.symeig",
"torch.lobpcg",
"torch.trapz",
"torch.triangular_solve",
"torch.compiled_with_cxx11_abi",
"torch.result_type",
"torch.can_cast",
"torch.promote_types"
]
\ No newline at end of file
......@@ -2,6 +2,7 @@ attrdict>=2.0.1
Click>=7.0
Flask>=1.1.1
Flask-Cors>=3.0.8
google-pasta>=0.1.8
gunicorn>=19.9.0
itsdangerous>=1.1.0
Jinja2>=2.10.1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册