提交 8e9d0b6b 编写于 作者: Q quyongxiu1

update warn infos and add interfaces for prompt infos and get ms api name

上级 83c104c4
...@@ -327,6 +327,48 @@ def load_json_file(file_path): ...@@ -327,6 +327,48 @@ def load_json_file(file_path):
return info return info
def get_corresponding_ms_name(pt_name):
"""
Get corresponding MindSpore op name for PyTorch name according to the mappings in mindconverter.
Args:
pt_name: PyTorch op name, whether shortened form or full name is available.
Returns:
str, full MindSpore op name, None if the op is not supported in mindconverter.
Raises:
ValueError, if get shortened form of MindSpore name not starts with `P` or 'nn', which means it is wrong in
the mappings file.
"""
helper = ALL_MAPPING.get(pt_name)
if helper is None:
return None
ms_name = helper.ms_api.name
if ms_name.startswith('nn.'):
full_ms_name = 'mindspore.' + ms_name
elif ms_name.startswith('P.'):
full_ms_name = 'mindspore.ops.operations.' + ms_name[len('P.'):]
else:
raise ValueError('check your mapping infos, the corresponding mindspore op name may wrong for torch op : '
'{}'.format(pt_name))
return full_ms_name
def get_prompt_info(pt_name):
"""
Get prompt info for PyTorch op name.
Args:
pt_name: PyTorch op name, whether shortened form or full name is available.
Returns:
str, prompt info on the op, None if no prompt info for the op.
"""
prompt_dict = {**UNSUPPORTED_WARN_INFOS, **SUPPORTED_WARN_INFOS}
return prompt_dict.get(pt_name)
# ---------------------------- mappings ---------------------------- # ---------------------------- mappings ----------------------------
NN_MAPPING_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'mappings/nn_mappings.json')) NN_MAPPING_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'mappings/nn_mappings.json'))
NN_MAPPING = get_mapping_from_file(NN_MAPPING_PATH) NN_MAPPING = get_mapping_from_file(NN_MAPPING_PATH)
...@@ -390,72 +432,98 @@ ALL_UNSUPPORTED = NN_UNSUPPORTED + F_UNSUPPORTED + TORCH_DOT_UNSUPPORTED + TENSO ...@@ -390,72 +432,98 @@ ALL_UNSUPPORTED = NN_UNSUPPORTED + F_UNSUPPORTED + TORCH_DOT_UNSUPPORTED + TENSO
UNSUPPORTED_WARN_INFOS = { UNSUPPORTED_WARN_INFOS = {
"nn.AdaptiveAvgPool2d": "maybe could convert to P.ReduceMean", "nn.AdaptiveAvgPool2d": "Maybe could convert to mindspore.ops.operations.ReduceMean.",
"nn.AvgPool1d": "maybe could convert to nn.AvgPool1d", "nn.AvgPool1d": "Maybe could convert to mindspore.nn.AvgPool1d.",
"nn.ConvTranspose2d": "maybe could convert to nn.Conv2dTranspose", "nn.ConvTranspose2d": "Maybe could convert to mindspore.nn.Conv2dTranspose.",
"nn.CrossEntropyLoss": "maybe could convert to nn.SoftmaxCrossEntropyWithLogits", "nn.CrossEntropyLoss": "Maybe could convert to mindspore.nn.SoftmaxCrossEntropyWithLogits.",
"nn.Embedding": "maybe could convert to nn.Embedding", "nn.Embedding": "Maybe could convert to mindspore.nn.Embedding.",
"nn.GroupNorm": "maybe could convert to nn.GroupNorm", "nn.GroupNorm": "Maybe could convert to mindspore.nn.GroupNorm.",
"nn.MSELoss": "maybe could convert to nn.MSELoss", "nn.MSELoss": "Maybe could convert to mindspore.nn.MSELoss.",
"nn.LSTM": "maybe could convert to nn.LSTM", "nn.LSTM": "Maybe could convert to mindspore.nn.LSTM.",
"nn.LSTMCell": "maybe could convert to nn.LSTMCell", "nn.LSTMCell": "Maybe could convert to mindspore.nn.LSTMCell.",
"nn.ModuleList": "maybe could convert to nn.CellList", "nn.ModuleList": "Maybe could convert to mindspore.nn.CellList.",
"nn.SmoothL1Loss": "maybe could convert to nn.SmoothL1Loss", "nn.SmoothL1Loss": "Maybe could convert to mindspore.nn.SmoothL1Loss.",
"nn.Tanh": "maybe could convert to nn.Tanh", "nn.Tanh": "Maybe could convert to mindspore.nn.Tanh.",
"nn.Upsample": "maybe could convert to P.ResizeBilinear", "nn.Upsample": "Maybe could convert to mindspore.ops.operations.ResizeBilinear.",
"nn.L1Loss": "maybe could convert to nn.L1Loss", "nn.L1Loss": "Maybe could convert to mindspore.nn.L1Loss.",
"nn.Parameter": "maybe could convert to mindspore.Parameter", "nn.Parameter": "Maybe could convert to mindspore.Parameter.",
"nn.ParameterList": "maybe could convert to mindspore.ParameterTuple", "nn.ParameterList": "Maybe could convert to mindspore.ParameterTuple.",
"nn.Unfold": "maybe could convert to nn.Unfold", "nn.Unfold": "Maybe could convert to mindspore.nn.Unfold.",
"nn.PixelShuffle": "maybe could convert to P.DepthToSpace", "nn.PixelShuffle": "Maybe could convert to mindspore.ops.operations.DepthToSpace.",
"F.adaptive_avg_pool2d": "maybe could convert to P.ReduceMean", "F.adaptive_avg_pool2d": "Maybe could convert to mindspore.ops.operations.ReduceMean.",
"F.conv2d": "maybe could convert to mindspore.ops.operations.Conv2D", "F.conv2d": "Maybe could convert to mindspore.ops.operations.Conv2D.",
"F.dropout": "please use nn.Dropout in __init__()", "F.dropout": "please use mindspore.nn.Dropout in __init__().",
"F.interpolate": "maybe could convert to P.ResizeBilinear", "F.interpolate": "Maybe could convert to mindspore.ops.operations.ResizeBilinear.",
"torch.bmm": "maybe could convert to P.BatchMatMul", "F.one_hot": "Maybe could convert to mindspore.ops.operations.OneHot.",
"torch.cumsum": "maybe could convert to P.CumSum", "torch.bmm": "Maybe could convert to mindspore.ops.operations.BatchMatMul.",
"F.relu": "maybe could convert to P.ReLU", "torch.cumsum": "Maybe could convert to mindspore.ops.operations.CumSum.",
"F.pad": "maybe could convert to P.Pad", "F.relu": "Maybe could convert to mindspore.ops.operations.ReLU.",
"F.softmax": "maybe could convert to P.Softmax", "F.pad": "Maybe could convert to mindspore.ops.operations.Pad.",
"torch.clamp": "maybe could convert to mindspore.ops.composite.clip_by_value", "F.softmax": "Maybe could convert to mindspore.ops.operations.Softmax.",
"torch.eq": "maybe could convert to P.Equal", "torch.clamp": "Maybe could convert to mindspore.ops.composite.clip_by_value.",
"torch.load": "maybe could convert to mindspore.train.serialization.load_checkpoint", "torch.eq": "Maybe could convert to mindspore.ops.operations.Equal.",
"torch.matmul": "maybe could convert to P.MatMul", "torch.load": "Maybe could convert to mindspore.train.serialization.load_checkpoint.",
"torch.max": "try to use P.ArgMaxWithValue, notice that two values are returned by P.ArgMaxWithValue", "torch.matmul": "Maybe could convert to mindspore.ops.operations.MatMul.",
"torch.mean": "maybe could convert to P.ReduceMean", "torch.max": "try to use P.ArgMaxWithValue, notice that two values are returned by mindspore.ops.operations."
"torch.min": "try to use P.ArgMinWithValue, notice that two values are returned by P.ArgMinWithValue", "ArgMaxWithValue.",
"torch.mm": "maybe could convert to P.MatMul", "torch.mean": "Maybe could convert to mindspore.ops.operations.ReduceMean.",
"torch.mul": "maybe could convert to P.Mul", "torch.min": "try to use P.ArgMinWithValue, notice that two values are returned by mindspore.ops.operations."
"torch.norm": "maybe could convert to nn.Norm", "ArgMinWithValue.",
"torch.numel": "maybe could convert to P.Size", "torch.mm": "Maybe could convert to mindspore.ops.operations.MatMul.",
"F.one_hot": "maybe could convert to P.OneHot", "torch.mul": "Maybe could convert to mindspore.ops.operations.Mul.",
"torch.ones_like": "maybe could convert to P.OnesLike", "torch.norm": "Maybe could convert to mindspore.nn.Norm.",
"torch.randn": "maybe could convert to P.TruncatedNormal", "torch.numel": "Maybe could convert to mindspore.ops.operations.Size.",
"torch.round": "maybe could convert to P.Round", "torch.ones_like": "Maybe could convert to mindspore.ops.operations.OnesLike.",
"torch.save": "maybe could convert to mindspore.train.serialization.save_checkpoint", "torch.randn": "Maybe could convert to mindspore.ops.operations.TruncatedNormal.",
"torch.sigmoid": "maybe could convert to P.Sigmoid", "torch.round": "Maybe could convert to mindspore.ops.operations.Round.",
"torch.split": "maybe could convert to P.Split", "torch.save": "Maybe could convert to mindspore.train.serialization.save_checkpoint.",
"torch.squeeze": "maybe could convert to P.Squeeze", "torch.sigmoid": "Maybe could convert to mindspore.ops.operations.Sigmoid.",
"torch.stack": "maybe could convert to P.Pack", "torch.split": "Maybe could convert to mindspore.ops.operations.Split.",
"torch.sum": "maybe could convert to mindspore.ops.operations.ReduceSum", "torch.squeeze": "Maybe could convert to mindspore.ops.operations.Squeeze.",
"torch.tanh": "maybe could convert to mindspore.ops.operations.Tanh", "torch.stack": "Maybe could convert to mindspore.ops.operations.Pack.",
"torch.tensor": "maybe could convert to mindspore.Tensor", "torch.sum": "Maybe could convert to mindspore.ops.operations.ReduceSum.",
"torch.transpose": "maybe could convert to P.Transpose", "torch.tanh": "Maybe could convert to mindspore.ops.operations.Tanh.",
"torch.unsqueeze": "maybe could convert to P.ExpandDims", "torch.tensor": "Maybe could convert to mindspore.Tensor.",
"torch.zeros_like": "maybe could convert to P.ZerosLike", "torch.transpose": "Maybe could convert to mindspore.ops.operations.Transpose.",
".chunk": "maybe could convert to P.Split", "torch.unsqueeze": "Maybe could convert to mindspore.ops.operations.ExpandDims.",
".fill_": "maybe could convert to P.Fill", "torch.zeros_like": "Maybe could convert to mindspore.ops.operations.ZerosLike.",
".float": "maybe could convert to P.Cast", ".chunk": "Maybe could convert to mindspore.ops.operations.Split.",
".mm": "maybe could convert to P.MatMul", ".fill_": "Maybe could convert to mindspore.ops.operations.Fill.",
"mul": "maybe could convert to P.Mul", ".float": "Maybe could convert to mindspore.ops.operations.Cast.",
".pow": "maybe could convert to P.Pow", ".mm": "Maybe could convert to mindspore.ops.operations.MatMul.",
".round": "maybe could convert to P.Round", ".mul": "Maybe could convert to mindspore.ops.operations.Mul.",
".scatter": "maybe could convert to P.ScatterNd", ".pow": "Maybe could convert to mindspore.ops.operations.Pow.",
"sigmoid": "maybe could convert to nn.Sigmoid", ".round": "Maybe could convert to mindspore.ops.operations.Round.",
".sign": "maybe could convert to P.Sign", ".scatter": "Maybe could convert to mindspore.ops.operations.ScatterNd.",
".sqrt": "maybe could convert to P.Sqrt", ".sigmoid": "Maybe could convert to mindspore.nn.Sigmoid.",
".sub": "maybe could convert to P.Sub", ".sign": "Maybe could convert to mindspore.ops.operations.Sign.",
".transpose": "maybe could convert to P.Transpose", ".sqrt": "Maybe could convert to mindspore.ops.operations.Sqrt.",
".unsqueeze": "maybe could convert to P.ExpandDims", ".sub": "Maybe could convert to mindspore.ops.operations.Sub.",
".zero_": "maybe could convert to P.ZerosLike", ".transpose": "Maybe could convert to mindspore.ops.operations.Transpose.",
".unsqueeze": "Maybe could convert to mindspore.ops.operations.ExpandDims.",
".zero_": "Maybe could convert to mindspore.ops.operations.ZerosLike.",
}
NN_UNSUPPORTED_INFOS = {k: v for k, v in UNSUPPORTED_WARN_INFOS.items() if k.startswith('nn.')}
TORCH_NN_UNSUPPORTED_INFOS = {('torch.' + k): v for k, v in NN_UNSUPPORTED_INFOS.items()}
F_UNSUPPORTED_INFOS = {k: v for k, v in UNSUPPORTED_WARN_INFOS.items() if k.startswith('F.')}
NN_FUNCTIONAL_UNSUPPORTED_INFOS = {'nn.functional.' + k[len('F.'):]: v for k, v in F_UNSUPPORTED_INFOS.items()}
TORCH_NN_FUNCTIONAL_UNSUPPORTED_INFOS = {'torch.nn.functional.' + k[len('F.'):]: v for k, v in
F_UNSUPPORTED_INFOS.items()}
UNSUPPORTED_WARN_INFOS.update(TORCH_NN_UNSUPPORTED_INFOS)
UNSUPPORTED_WARN_INFOS.update(NN_FUNCTIONAL_UNSUPPORTED_INFOS)
UNSUPPORTED_WARN_INFOS.update(TORCH_NN_FUNCTIONAL_UNSUPPORTED_INFOS)
SUPPORTED_WARN_INFOS = {
"torch.eye": "Pay attention to use right mindspore data type.",
"nn.Linear": "Pay attention to reshape the input to 2 dims if it is 3 dims before, because MindSpore.nn.Dense only "
"support 2-dim input.",
".view": "Only float Tensor is supported in mindspore.ops.operations.Reshape.",
".reshape": "Only float Tensor is supported in mindspore.ops.operations.Reshape."
} }
NN_SUPPORTED_INFOS = {k: v for k, v in SUPPORTED_WARN_INFOS.items() if k.startswith('nn.')}
TORCH_NN_SUPPORTED_INFOS = {('torch.' + k): v for k, v in NN_SUPPORTED_INFOS.items()}
SUPPORTED_WARN_INFOS.update(TORCH_NN_SUPPORTED_INFOS)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册