提交 8e9d0b6b 编写于 作者: Q quyongxiu1

update warn infos and add interfaces for prompt infos and get ms api name

上级 83c104c4
......@@ -327,6 +327,48 @@ def load_json_file(file_path):
return info
def get_corresponding_ms_name(pt_name):
"""
Get corresponding MindSpore op name for PyTorch name according to the mappings in mindconverter.
Args:
pt_name: PyTorch op name, whether shortened form or full name is available.
Returns:
str, full MindSpore op name, None if the op is not supported in mindconverter.
Raises:
ValueError, if get shortened form of MindSpore name not starts with `P` or 'nn', which means it is wrong in
the mappings file.
"""
helper = ALL_MAPPING.get(pt_name)
if helper is None:
return None
ms_name = helper.ms_api.name
if ms_name.startswith('nn.'):
full_ms_name = 'mindspore.' + ms_name
elif ms_name.startswith('P.'):
full_ms_name = 'mindspore.ops.operations.' + ms_name[len('P.'):]
else:
raise ValueError('check your mapping infos, the corresponding mindspore op name may wrong for torch op : '
'{}'.format(pt_name))
return full_ms_name
def get_prompt_info(pt_name):
"""
Get prompt info for PyTorch op name.
Args:
pt_name: PyTorch op name, whether shortened form or full name is available.
Returns:
str, prompt info on the op, None if no prompt info for the op.
"""
prompt_dict = {**UNSUPPORTED_WARN_INFOS, **SUPPORTED_WARN_INFOS}
return prompt_dict.get(pt_name)
# ---------------------------- mappings ----------------------------
NN_MAPPING_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'mappings/nn_mappings.json'))
NN_MAPPING = get_mapping_from_file(NN_MAPPING_PATH)
......@@ -390,72 +432,98 @@ ALL_UNSUPPORTED = NN_UNSUPPORTED + F_UNSUPPORTED + TORCH_DOT_UNSUPPORTED + TENSO
UNSUPPORTED_WARN_INFOS = {
"nn.AdaptiveAvgPool2d": "maybe could convert to P.ReduceMean",
"nn.AvgPool1d": "maybe could convert to nn.AvgPool1d",
"nn.ConvTranspose2d": "maybe could convert to nn.Conv2dTranspose",
"nn.CrossEntropyLoss": "maybe could convert to nn.SoftmaxCrossEntropyWithLogits",
"nn.Embedding": "maybe could convert to nn.Embedding",
"nn.GroupNorm": "maybe could convert to nn.GroupNorm",
"nn.MSELoss": "maybe could convert to nn.MSELoss",
"nn.LSTM": "maybe could convert to nn.LSTM",
"nn.LSTMCell": "maybe could convert to nn.LSTMCell",
"nn.ModuleList": "maybe could convert to nn.CellList",
"nn.SmoothL1Loss": "maybe could convert to nn.SmoothL1Loss",
"nn.Tanh": "maybe could convert to nn.Tanh",
"nn.Upsample": "maybe could convert to P.ResizeBilinear",
"nn.L1Loss": "maybe could convert to nn.L1Loss",
"nn.Parameter": "maybe could convert to mindspore.Parameter",
"nn.ParameterList": "maybe could convert to mindspore.ParameterTuple",
"nn.Unfold": "maybe could convert to nn.Unfold",
"nn.PixelShuffle": "maybe could convert to P.DepthToSpace",
"F.adaptive_avg_pool2d": "maybe could convert to P.ReduceMean",
"F.conv2d": "maybe could convert to mindspore.ops.operations.Conv2D",
"F.dropout": "please use nn.Dropout in __init__()",
"F.interpolate": "maybe could convert to P.ResizeBilinear",
"torch.bmm": "maybe could convert to P.BatchMatMul",
"torch.cumsum": "maybe could convert to P.CumSum",
"F.relu": "maybe could convert to P.ReLU",
"F.pad": "maybe could convert to P.Pad",
"F.softmax": "maybe could convert to P.Softmax",
"torch.clamp": "maybe could convert to mindspore.ops.composite.clip_by_value",
"torch.eq": "maybe could convert to P.Equal",
"torch.load": "maybe could convert to mindspore.train.serialization.load_checkpoint",
"torch.matmul": "maybe could convert to P.MatMul",
"torch.max": "try to use P.ArgMaxWithValue, notice that two values are returned by P.ArgMaxWithValue",
"torch.mean": "maybe could convert to P.ReduceMean",
"torch.min": "try to use P.ArgMinWithValue, notice that two values are returned by P.ArgMinWithValue",
"torch.mm": "maybe could convert to P.MatMul",
"torch.mul": "maybe could convert to P.Mul",
"torch.norm": "maybe could convert to nn.Norm",
"torch.numel": "maybe could convert to P.Size",
"F.one_hot": "maybe could convert to P.OneHot",
"torch.ones_like": "maybe could convert to P.OnesLike",
"torch.randn": "maybe could convert to P.TruncatedNormal",
"torch.round": "maybe could convert to P.Round",
"torch.save": "maybe could convert to mindspore.train.serialization.save_checkpoint",
"torch.sigmoid": "maybe could convert to P.Sigmoid",
"torch.split": "maybe could convert to P.Split",
"torch.squeeze": "maybe could convert to P.Squeeze",
"torch.stack": "maybe could convert to P.Pack",
"torch.sum": "maybe could convert to mindspore.ops.operations.ReduceSum",
"torch.tanh": "maybe could convert to mindspore.ops.operations.Tanh",
"torch.tensor": "maybe could convert to mindspore.Tensor",
"torch.transpose": "maybe could convert to P.Transpose",
"torch.unsqueeze": "maybe could convert to P.ExpandDims",
"torch.zeros_like": "maybe could convert to P.ZerosLike",
".chunk": "maybe could convert to P.Split",
".fill_": "maybe could convert to P.Fill",
".float": "maybe could convert to P.Cast",
".mm": "maybe could convert to P.MatMul",
"mul": "maybe could convert to P.Mul",
".pow": "maybe could convert to P.Pow",
".round": "maybe could convert to P.Round",
".scatter": "maybe could convert to P.ScatterNd",
"sigmoid": "maybe could convert to nn.Sigmoid",
".sign": "maybe could convert to P.Sign",
".sqrt": "maybe could convert to P.Sqrt",
".sub": "maybe could convert to P.Sub",
".transpose": "maybe could convert to P.Transpose",
".unsqueeze": "maybe could convert to P.ExpandDims",
".zero_": "maybe could convert to P.ZerosLike",
"nn.AdaptiveAvgPool2d": "Maybe could convert to mindspore.ops.operations.ReduceMean.",
"nn.AvgPool1d": "Maybe could convert to mindspore.nn.AvgPool1d.",
"nn.ConvTranspose2d": "Maybe could convert to mindspore.nn.Conv2dTranspose.",
"nn.CrossEntropyLoss": "Maybe could convert to mindspore.nn.SoftmaxCrossEntropyWithLogits.",
"nn.Embedding": "Maybe could convert to mindspore.nn.Embedding.",
"nn.GroupNorm": "Maybe could convert to mindspore.nn.GroupNorm.",
"nn.MSELoss": "Maybe could convert to mindspore.nn.MSELoss.",
"nn.LSTM": "Maybe could convert to mindspore.nn.LSTM.",
"nn.LSTMCell": "Maybe could convert to mindspore.nn.LSTMCell.",
"nn.ModuleList": "Maybe could convert to mindspore.nn.CellList.",
"nn.SmoothL1Loss": "Maybe could convert to mindspore.nn.SmoothL1Loss.",
"nn.Tanh": "Maybe could convert to mindspore.nn.Tanh.",
"nn.Upsample": "Maybe could convert to mindspore.ops.operations.ResizeBilinear.",
"nn.L1Loss": "Maybe could convert to mindspore.nn.L1Loss.",
"nn.Parameter": "Maybe could convert to mindspore.Parameter.",
"nn.ParameterList": "Maybe could convert to mindspore.ParameterTuple.",
"nn.Unfold": "Maybe could convert to mindspore.nn.Unfold.",
"nn.PixelShuffle": "Maybe could convert to mindspore.ops.operations.DepthToSpace.",
"F.adaptive_avg_pool2d": "Maybe could convert to mindspore.ops.operations.ReduceMean.",
"F.conv2d": "Maybe could convert to mindspore.ops.operations.Conv2D.",
"F.dropout": "please use mindspore.nn.Dropout in __init__().",
"F.interpolate": "Maybe could convert to mindspore.ops.operations.ResizeBilinear.",
"F.one_hot": "Maybe could convert to mindspore.ops.operations.OneHot.",
"torch.bmm": "Maybe could convert to mindspore.ops.operations.BatchMatMul.",
"torch.cumsum": "Maybe could convert to mindspore.ops.operations.CumSum.",
"F.relu": "Maybe could convert to mindspore.ops.operations.ReLU.",
"F.pad": "Maybe could convert to mindspore.ops.operations.Pad.",
"F.softmax": "Maybe could convert to mindspore.ops.operations.Softmax.",
"torch.clamp": "Maybe could convert to mindspore.ops.composite.clip_by_value.",
"torch.eq": "Maybe could convert to mindspore.ops.operations.Equal.",
"torch.load": "Maybe could convert to mindspore.train.serialization.load_checkpoint.",
"torch.matmul": "Maybe could convert to mindspore.ops.operations.MatMul.",
"torch.max": "try to use P.ArgMaxWithValue, notice that two values are returned by mindspore.ops.operations."
"ArgMaxWithValue.",
"torch.mean": "Maybe could convert to mindspore.ops.operations.ReduceMean.",
"torch.min": "try to use P.ArgMinWithValue, notice that two values are returned by mindspore.ops.operations."
"ArgMinWithValue.",
"torch.mm": "Maybe could convert to mindspore.ops.operations.MatMul.",
"torch.mul": "Maybe could convert to mindspore.ops.operations.Mul.",
"torch.norm": "Maybe could convert to mindspore.nn.Norm.",
"torch.numel": "Maybe could convert to mindspore.ops.operations.Size.",
"torch.ones_like": "Maybe could convert to mindspore.ops.operations.OnesLike.",
"torch.randn": "Maybe could convert to mindspore.ops.operations.TruncatedNormal.",
"torch.round": "Maybe could convert to mindspore.ops.operations.Round.",
"torch.save": "Maybe could convert to mindspore.train.serialization.save_checkpoint.",
"torch.sigmoid": "Maybe could convert to mindspore.ops.operations.Sigmoid.",
"torch.split": "Maybe could convert to mindspore.ops.operations.Split.",
"torch.squeeze": "Maybe could convert to mindspore.ops.operations.Squeeze.",
"torch.stack": "Maybe could convert to mindspore.ops.operations.Pack.",
"torch.sum": "Maybe could convert to mindspore.ops.operations.ReduceSum.",
"torch.tanh": "Maybe could convert to mindspore.ops.operations.Tanh.",
"torch.tensor": "Maybe could convert to mindspore.Tensor.",
"torch.transpose": "Maybe could convert to mindspore.ops.operations.Transpose.",
"torch.unsqueeze": "Maybe could convert to mindspore.ops.operations.ExpandDims.",
"torch.zeros_like": "Maybe could convert to mindspore.ops.operations.ZerosLike.",
".chunk": "Maybe could convert to mindspore.ops.operations.Split.",
".fill_": "Maybe could convert to mindspore.ops.operations.Fill.",
".float": "Maybe could convert to mindspore.ops.operations.Cast.",
".mm": "Maybe could convert to mindspore.ops.operations.MatMul.",
".mul": "Maybe could convert to mindspore.ops.operations.Mul.",
".pow": "Maybe could convert to mindspore.ops.operations.Pow.",
".round": "Maybe could convert to mindspore.ops.operations.Round.",
".scatter": "Maybe could convert to mindspore.ops.operations.ScatterNd.",
".sigmoid": "Maybe could convert to mindspore.nn.Sigmoid.",
".sign": "Maybe could convert to mindspore.ops.operations.Sign.",
".sqrt": "Maybe could convert to mindspore.ops.operations.Sqrt.",
".sub": "Maybe could convert to mindspore.ops.operations.Sub.",
".transpose": "Maybe could convert to mindspore.ops.operations.Transpose.",
".unsqueeze": "Maybe could convert to mindspore.ops.operations.ExpandDims.",
".zero_": "Maybe could convert to mindspore.ops.operations.ZerosLike.",
}
NN_UNSUPPORTED_INFOS = {k: v for k, v in UNSUPPORTED_WARN_INFOS.items() if k.startswith('nn.')}
TORCH_NN_UNSUPPORTED_INFOS = {('torch.' + k): v for k, v in NN_UNSUPPORTED_INFOS.items()}
F_UNSUPPORTED_INFOS = {k: v for k, v in UNSUPPORTED_WARN_INFOS.items() if k.startswith('F.')}
NN_FUNCTIONAL_UNSUPPORTED_INFOS = {'nn.functional.' + k[len('F.'):]: v for k, v in F_UNSUPPORTED_INFOS.items()}
TORCH_NN_FUNCTIONAL_UNSUPPORTED_INFOS = {'torch.nn.functional.' + k[len('F.'):]: v for k, v in
F_UNSUPPORTED_INFOS.items()}
UNSUPPORTED_WARN_INFOS.update(TORCH_NN_UNSUPPORTED_INFOS)
UNSUPPORTED_WARN_INFOS.update(NN_FUNCTIONAL_UNSUPPORTED_INFOS)
UNSUPPORTED_WARN_INFOS.update(TORCH_NN_FUNCTIONAL_UNSUPPORTED_INFOS)
SUPPORTED_WARN_INFOS = {
"torch.eye": "Pay attention to use right mindspore data type.",
"nn.Linear": "Pay attention to reshape the input to 2 dims if it is 3 dims before, because MindSpore.nn.Dense only "
"support 2-dim input.",
".view": "Only float Tensor is supported in mindspore.ops.operations.Reshape.",
".reshape": "Only float Tensor is supported in mindspore.ops.operations.Reshape."
}
NN_SUPPORTED_INFOS = {k: v for k, v in SUPPORTED_WARN_INFOS.items() if k.startswith('nn.')}
TORCH_NN_SUPPORTED_INFOS = {('torch.' + k): v for k, v in NN_SUPPORTED_INFOS.items()}
SUPPORTED_WARN_INFOS.update(TORCH_NN_SUPPORTED_INFOS)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册