diff --git a/mindinsight/mindconverter/config.py b/mindinsight/mindconverter/config.py index c2fe7a8082deceece7a914ed2eb94841ac3e28f2..34de567d16013410563920148fa604b96b008c58 100644 --- a/mindinsight/mindconverter/config.py +++ b/mindinsight/mindconverter/config.py @@ -15,17 +15,18 @@ """API config""" import ast from collections import OrderedDict -from functools import partial +from importlib import import_module import json import os import pasta -from mindinsight.mindconverter.enums import RequriedType from mindinsight.mindconverter.common.log import logger -REQUIRED = RequriedType.REQUIRED.name -UNREQUIRED = RequriedType.UNREQUIRED.name + +REQUIRED = 'REQUIRED' +UNREQUIRED = 'UNREQUIRED' +FUNC_MODULE = 'mindinsight.mindconverter.funcs' class APIPt: @@ -250,88 +251,65 @@ class MappingHelper: return expr_ms -def gen_explicit_map_nn_sequential(_, args_pt): +def get_ms_api(ms_api_info): """ - Generate explicit_map for nn.Sequential. + Get APIMs instance from ms_api_info. Args: - args_pt (dict): Args for APIPt. + ms_api_info (list): info for create an APIMs instance, the first value in list is name for APIMs, the second(if + provided) is params for APIMs, the third(if provided) is p_attrs for APIMs. Returns: - dict, map between frames. + APIMs, instance of APIMs parsed from given info. """ - args = args_pt['*args'] - return {"*args": "[{}]".format(args)} + ms_name = ms_api_info[0] + ms_params = ms_api_info[1] if len(ms_api_info) >= 2 else None + ms_p_attrs = set(ms_api_info[2]) if len(ms_api_info) >= 3 else None + ms_api = APIMs(name=ms_name, params=ms_params, p_attrs=ms_p_attrs) + return ms_api -def gen_explicit_map_nn_maxpool2d(params_pt, args_pt): +def get_pt_api(pt_api_info): """ - Generate explicit_map for nn.MaxPool2d. + Get APIPt instance from pt_api_info. Args: - params_pt (dict): Params for APIPt. - args_pt (dict): Args for APIPt. + pt_api_info (list): info for create an APIMs instance, the first value in list is name for APIPt, the second(if + provided) is params for APIPt. Returns: - dict, map between frames. - """ - if 'padding' in args_pt: - padding = args_pt['padding'] - else: - padding = params_pt['padding'] - if padding.strip() in ("0", "(0,0)", "(0, 0)"): - pad_mode = "'valid'" - else: - pad_mode = "'same'" - return {"pad_mode": pad_mode} - - -def gen_explicit_map_f_max_pool2d(params_pt, args_pt): + APIMs, instance of APIMs parsed from given info. """ - Generate explicit_map for F.MaxPool2d. + pt_name = pt_api_info[0] + pt_params = pt_api_info[1] if len(pt_api_info) >= 2 else None + pt_api = APIPt(name=pt_name, params=pt_params) + return pt_api - Args: - params_pt (dict): Params for APIPt. - args_pt (dict): Args for APIPt. - Returns: - dict, map between frames. +def get_mapping_from_file(path): """ - if 'padding' in args_pt: - padding = args_pt['padding'] - else: - padding = params_pt['padding'] - if padding.strip() in ("0", "(0,0)", "(0, 0)"): - padding = "'valid'" - else: - padding = "'same'" - return {"padding": padding} - - -def gen_explicit_map_one_delta(params_pt, args_pt, k_ms, k_pt): - """ - Generate explicit_map for which include mapping relationship is `1 - k_ms = k_pt`. + Parse mapping info from given file. Args: - params_pt (dict): Params for APIPt. - args_pt (dict): Args for APIPt. + path (str): The file path. Returns: - dict, map between frames. + dict, key is op name, value is a relevant instance of MappingHelper. """ - value = args_pt[k_pt] if k_pt in args_pt else params_pt[k_pt] - value = value.strip() - - def is_number(string): - try: - float(string) - return True - except ValueError: - return False - - if is_number(value): - return {k_ms: str(1 - float(value))} - return {k_ms: "1.0 - " + value} + mapping_info_d = load_json_file(path) + parse_mapping_dict = {} + for key, value in mapping_info_d.items(): + ms_api_info = value.pop('ms_api') + ms_api = get_ms_api(ms_api_info) + pt_api_info = value.pop('pt_api') + pt_api = get_pt_api(pt_api_info) + gen_explicit_map = value.get('gen_explicit_map') + if gen_explicit_map: + module_name = import_module(FUNC_MODULE) + value['gen_explicit_map'] = getattr(module_name, gen_explicit_map) + + parse_mapping_dict.update({key: MappingHelper(**dict(ms_api=ms_api, pt_api=pt_api), **value)}) + return parse_mapping_dict def load_json_file(file_path): @@ -350,244 +328,38 @@ def load_json_file(file_path): # ---------------------------- mappings ---------------------------- -NN_MAPPING = { - 'nn.Sequential': MappingHelper(**{"ms_api": APIMs('nn.SequentialCell', OrderedDict([('*args', REQUIRED)])), - "pt_api": APIPt('nn.Sequential', OrderedDict([('*args', REQUIRED)])), - "gen_explicit_map": gen_explicit_map_nn_sequential, - "export_key": False - }), - 'nn.Conv2d': MappingHelper(**{"ms_api": APIMs('nn.Conv2d', OrderedDict(in_channels=REQUIRED, - out_channels=REQUIRED, - kernel_size=REQUIRED, - stride=1, - pad_mode='same', - padding=0, - dilation=1, - group=1, - has_bias=False, - weight_init='normal', - bias_init='zeros')), - "pt_api": APIPt('nn.Conv2d', OrderedDict(in_channels=REQUIRED, - out_channels=REQUIRED, - kernel_size=REQUIRED, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=True, - padding_mode='zeros')), - "ms2pt_mapping": {'in_channels': 'in_channels', - 'out_channels': 'out_channels', - 'kernel_size': 'kernel_size', - 'stride': 'stride', - 'padding': 'padding', - 'dilation': 'dilation', - 'group': 'groups', - 'has_bias': 'bias'}, - "gen_explicit_map": (lambda params_pt, args_pt: {"pad_mode": "'pad'"}) - }), - 'nn.BatchNorm2d': MappingHelper(**{"ms_api": APIMs('nn.BatchNorm2d', OrderedDict(num_features=REQUIRED, - eps=1e-5, - momentum=0.9, - affine=True, - gamma_init='ones', - beta_init='zeros', - moving_mean_init='zeros', - moving_var_init='ones', - use_batch_statistics=True)), - "pt_api": APIPt('nn.BatchNorm2d', OrderedDict(num_features=REQUIRED, - eps=1e-5, - momentum=0.1, - affine=True, - track_running_stats=True)), - "ms2pt_mapping": {"num_features": "num_features", - "eps": "eps", - "affine": "affine", - "use_batch_statistics": "track_running_stats"}, - "gen_explicit_map": partial(gen_explicit_map_one_delta, - k_ms="momentum", k_pt="momentum") - }), - 'nn.ReLU': MappingHelper(**{"ms_api": APIMs('nn.ReLU', OrderedDict()), - "pt_api": APIPt('nn.ReLU', OrderedDict(inplace=False)), - "ms2pt_mapping": {}}), - 'nn.ReLU6': MappingHelper(**{"ms_api": APIMs('nn.ReLU6', OrderedDict()), - "pt_api": APIPt('nn.ReLU6', OrderedDict(inplace=False)), - "ms2pt_mapping": {}}), - 'nn.Linear': MappingHelper(**{"ms_api": APIMs('nn.Dense', OrderedDict(in_channels=REQUIRED, - out_channels=REQUIRED, - weight_init='normal', - bias_init='zeros', - has_bias=True, - activation=None)), - "pt_api": APIPt('nn.Linear', OrderedDict(in_features=REQUIRED, - out_features=REQUIRED, - bias=True)), - "ms2pt_mapping": {"in_channels": "in_features", - "out_channels": "out_features", - "has_bias": "bias"} - }), - 'nn.MaxPool2d': MappingHelper(**{"ms_api": APIMs('nn.MaxPool2d', OrderedDict(kernel_size=1, - stride=1, - pad_mode="valid")), - "pt_api": APIPt('nn.MaxPool2d', OrderedDict(kernel_size=REQUIRED, - stride=None, - padding=0, - dilation=1, - return_indices=False, - ceil_mode="False")), - "ms2pt_mapping": {"kernel_size": "kernel_size", - "stride": "stride"}, - "gen_explicit_map": gen_explicit_map_nn_maxpool2d - }), - 'nn.AvgPool2d': MappingHelper(**{"ms_api": APIMs('nn.AvgPool2d', OrderedDict(kernel_size=1, - stride=1, - pad_mode="valid")), - "pt_api": APIPt('nn.AvgPool2d', OrderedDict(kernel_size=REQUIRED, - stride=None, - padding=0, - dilation=1, - return_indices=False, - ceil_mode="False")), - "ms2pt_mapping": {"kernel_size": "kernel_size", - "stride": "stride"}, - "gen_explicit_map": gen_explicit_map_nn_maxpool2d - }), - 'nn.Dropout': MappingHelper(**{"ms_api": APIMs('nn.Dropout', OrderedDict(keep_prob=0.5, - seed0=0, - seed1=0, - dtype="mstype.float32")), - "pt_api": APIPt('nn.Dropout', OrderedDict(p=0.5, - inplace=False)), - "ms2pt_mapping": {"keep_prob": "p"}, - "gen_explicit_map": partial(gen_explicit_map_one_delta, - k_ms="keep_prob", k_pt="p") - }) -} +NN_MAPPING_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'mappings/nn_mappings.json')) +NN_MAPPING = get_mapping_from_file(NN_MAPPING_PATH) +# update to add key with full api_name, which starts with 'torch.nn.' NN_MAPPING.update({"torch." + k: v for k, v in NN_MAPPING.items()}) +F_MAPPING_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'mappings/f_mappings.json')) +F_MAPPING = get_mapping_from_file(F_MAPPING_PATH) +# update to add key starts with 'nn.functional.' +NN_FUNCTIONAL_D = {"nn.functional." + k[len('F.'):]: v for k, v in F_MAPPING.items()} +# update to add key starts with 'torch.nn.functiona.l' +TORCH_NN_FUNCTIONAL_D = {"torch.nn.functional." + k[len('F.'):]: v for k, v in F_MAPPING.items()} +F_MAPPING.update(NN_FUNCTIONAL_D) +F_MAPPING.update(TORCH_NN_FUNCTIONAL_D) -F_MAPPING = { - 'F.relu': MappingHelper(**{"ms_api": APIMs('P.ReLU', OrderedDict(input=REQUIRED)), - "pt_api": APIPt('F.relu', OrderedDict(input=REQUIRED, inplace=False)), - "ms2pt_mapping": {"input": "input"}, - }), - 'F.relu6': MappingHelper(**{"ms_api": APIMs('P.ReLU6', OrderedDict(input=REQUIRED)), - "pt_api": APIPt('F.relu6', OrderedDict(input=REQUIRED, inplace=False)), - "ms2pt_mapping": {"input": "input"}, - }), - 'F.max_pool2d': MappingHelper(**{"ms_api": APIMs('P.MaxPool', OrderedDict(ksize=1, - strides=1, - padding="valid", - input=REQUIRED), - p_attrs={"ksize", "strides", "padding"}), - "pt_api": APIPt('F.max_pool2d', OrderedDict(input=REQUIRED, - kernel_size=REQUIRED, - stride=None, - padding=0, - dilation=1, - ceil_mode=False, - return_indices=False)), - "ms2pt_mapping": {"ksize": "kernel_size", - "strides": "stride", - "input": "input", - }, - "gen_explicit_map": gen_explicit_map_f_max_pool2d - }), - 'F.avg_pool2d': MappingHelper(**{"ms_api": APIMs('P.AvgPool', OrderedDict(ksize=1, - strides=1, - padding="valid", - input=REQUIRED), - p_attrs={"ksize", "strides", "padding"}), - "pt_api": APIPt('F.avg_pool2d', OrderedDict(input=REQUIRED, - kernel_size=REQUIRED, - stride=None, - padding=0, - dilation=1, - ceil_mode=False, - return_indices=False)), - "ms2pt_mapping": {"ksize": "kernel_size", - "strides": "stride", - "input": "input", - }, - "gen_explicit_map": gen_explicit_map_f_max_pool2d - }), -} -nn_functional_d = {"nn.functional." + k[2:]: v for k, v in F_MAPPING.items()} -torch_nn_functional_d = {"torch.nn.functional." + k[2:]: v for k, v in F_MAPPING.items()} -F_MAPPING.update(nn_functional_d) -F_MAPPING.update(torch_nn_functional_d) - - -TORCH_DOT_MAPPING = { - 'torch.flatten': MappingHelper(**{"ms_api": APIMs('P.Flatten', OrderedDict(input=REQUIRED)), - "pt_api": APIPt('torch.flatten', OrderedDict(input=REQUIRED, - start_dim=0, - end_dim=-1)), - "ms2pt_mapping": {"input": "input"} - }), - 'torch.cat': MappingHelper(**{"ms_api": APIMs('P.Concat', - OrderedDict(axis=0, input=REQUIRED), - p_attrs={"axis"}), - "pt_api": APIPt('torch.flatten', OrderedDict(tensors=REQUIRED, dim=0, out=None)), - "ms2pt_mapping": {"input": "tensors", - "axis": "dim"} - }), -} - - -TENSOR_DOT_MAPPING = { - '.view': MappingHelper(**{"ms_api": APIMs('P.Reshape', OrderedDict(x=REQUIRED, shape=REQUIRED)), - "pt_api": APIPt('.view', OrderedDict([('*shape', REQUIRED)])), - "ms2pt_mapping": {"x": "call_name"}, - "gen_explicit_map": (lambda params_pt, args_pt: {"shape": "(" + args_pt["*shape"] + ",)"}) - }), - '.size': MappingHelper(**{"ms_api": APIMs('P.Shape', OrderedDict(x=REQUIRED)), - "pt_api": APIPt('.size', OrderedDict([('idx', REQUIRED)])), - "ms2pt_mapping": {"x": "call_name"} - }), - '.flatten': MappingHelper(**{"ms_api": APIMs('P.Flatten', OrderedDict(input=REQUIRED)), - "pt_api": APIPt('.flatten', OrderedDict(start_dim=0, - end_dim=-1)), - "ms2pt_mapping": {"input": "call_name"} - }), - '.reshape': MappingHelper(**{"ms_api": APIMs('P.Reshape', OrderedDict(x=REQUIRED, shape=REQUIRED)), - "pt_api": APIPt('.reshape', OrderedDict([('*shape', REQUIRED)])), - "ms2pt_mapping": {"x": "call_name"}, - "gen_explicit_map": ( - lambda params_pt, args_pt: {"shape": "(" + args_pt["*shape"] + ",)"}) - }), - '.mean': MappingHelper(**{"ms_api": APIMs('P.ReduceMean', OrderedDict(keep_dims=False, - input=REQUIRED, - axis=()), - p_attrs={"keep_dims"}), - "pt_api": APIPt('.mean', OrderedDict(dim=None, - keepdim=False)), - "ms2pt_mapping": {"keep_dims": "keepdim", - "axis": "dim", - "input": "call_name"}, - }), - '.squeeze': MappingHelper(**{"ms_api": APIMs('P.ReduceMean', OrderedDict(input=REQUIRED, - axis=()), - p_attrs={"axis"}), - "pt_api": APIPt('.squeeze', OrderedDict(dim=None)), - "ms2pt_mapping": {"axis": "dim", - "input": "call_name"}, - }), -} +TORCH_DOT_MAPPING_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'mappings/torch_dot_mappings.json')) +TORCH_DOT_MAPPING = get_mapping_from_file(TORCH_DOT_MAPPING_PATH) +TENSOR_DOT_MAPPING_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'mappings/tensor_dot_mappings.json')) +TENSOR_DOT_MAPPING = get_mapping_from_file(TENSOR_DOT_MAPPING_PATH) ALL_MAPPING = {**NN_MAPPING, **F_MAPPING, **TORCH_DOT_MAPPING, **TENSOR_DOT_MAPPING} # ---------------------------- api list support or not support ---------------------------- -NN_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'nn_list.json')) +NN_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'nn_list.json')) NN_LIST = load_json_file(NN_LIST_PATH) NN_LIST += ["torch." + name for name in NN_LIST] NN_SUPPORTED = [x for x in NN_LIST if x in ALL_MAPPING] NN_UNSUPPORTED = [x for x in NN_LIST if x not in ALL_MAPPING] -F_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'f_list.json')) +F_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'f_list.json')) F_LIST = load_json_file(F_LIST_PATH) F_LIST += ["F." + name[len("torch.nn.functional."):] for name in F_LIST] + \ [name[len("torch."):] for name in F_LIST] @@ -595,7 +367,7 @@ F_SUPPORTED = [x for x in F_LIST if x in ALL_MAPPING] F_UNSUPPORTED = [x for x in F_LIST if x not in ALL_MAPPING] -TORCH_DOT_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'torch_dot_list.json')) +TORCH_DOT_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'torch_dot_list.json')) TORCH_DOT_LIST = load_json_file(TORCH_DOT_LIST_PATH) @@ -603,7 +375,7 @@ TORCH_DOT_SUPPORTED = [x for x in TORCH_DOT_LIST if x in ALL_MAPPING] TORCH_DOT_UNSUPPORTED = [x for x in TORCH_DOT_LIST if x not in ALL_MAPPING] -TENSOR_DOT_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'tensor_dot_list.json')) +TENSOR_DOT_LIST_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'ops', 'tensor_dot_list.json')) TENSOR_DOT_LIST = load_json_file(TENSOR_DOT_LIST_PATH) @@ -620,5 +392,5 @@ ALL_UNSUPPORTED = NN_UNSUPPORTED + F_UNSUPPORTED + TORCH_DOT_UNSUPPORTED + TENSO UNSUPPORTED_WARN_INFOS = { "nn.AdaptiveAvgPool2d": "maybe could convert to P.ReduceMean", "F.adaptive_avg_pool2d": "maybe could convert to P.ReduceMean", - "F.dropout": "please use nn.Dropout in __init__()", + "F.dropout": "please use nn.Dropout in __init__()" } diff --git a/mindinsight/mindconverter/enums.py b/mindinsight/mindconverter/enums.py deleted file mode 100644 index 3a8c0f34438468892b7c4b43f39cebc7b57635af..0000000000000000000000000000000000000000 --- a/mindinsight/mindconverter/enums.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Enums.""" -from enum import Enum - - -class RequriedType(Enum): - """If param is required""" - REQUIRED = 1 - UNREQUIRED = 2 diff --git a/mindinsight/mindconverter/funcs.py b/mindinsight/mindconverter/funcs.py new file mode 100644 index 0000000000000000000000000000000000000000..eb2b57cfbb108f72c7221755a0263b9e79814464 --- /dev/null +++ b/mindinsight/mindconverter/funcs.py @@ -0,0 +1,106 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless REQUIRED by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""funcs for gen_explicit_map""" +from functools import partial + + +def gen_explicit_map_f_max_pool2d(params_pt, args_pt): + """ + Generate explicit_map for F.MaxPool2d. + + Args: + params_pt (dict): Params for APIPt. + args_pt (dict): Args for APIPt. + + Returns: + dict, map between frames. + """ + if 'padding' in args_pt: + padding = args_pt['padding'] + else: + padding = params_pt['padding'] + if padding.strip() in ("0", "(0,0)", "(0, 0)"): + padding = "'valid'" + else: + padding = "'same'" + return {"padding": padding} + + +def gen_explicit_map_nn_sequential(_, args_pt): + """ + Generate explicit_map for nn.Sequential. + + Args: + args_pt (dict): Args for APIPt. + + Returns: + dict, map between frames. + """ + args = args_pt['*args'] + return {"*args": "[{}]".format(args)} + + +def gen_explicit_map_one_delta(params_pt, args_pt, k_ms, k_pt): + """ + Generate explicit_map for which include mapping relationship is `1 - k_ms = k_pt`. + + Args: + params_pt (dict): Params for APIPt. + args_pt (dict): Args for APIPt. + + Returns: + dict, map between frames. + """ + value = args_pt[k_pt] if k_pt in args_pt else params_pt[k_pt] + value = value.strip() + + def is_number(string): + try: + float(string) + return True + except ValueError: + return False + + if is_number(value): + return {k_ms: str(1 - float(value))} + return {k_ms: "1.0 - " + value} + + +def gen_explicit_map_nn_maxpool2d(params_pt, args_pt): + """ + Generate explicit_map for nn.MaxPool2d. + + Args: + params_pt (dict): Params for APIPt. + args_pt (dict): Args for APIPt. + + Returns: + dict, map between frames. + """ + if 'padding' in args_pt: + padding = args_pt['padding'] + else: + padding = params_pt['padding'] + if padding.strip() in ("0", "(0,0)", "(0, 0)"): + pad_mode = "'valid'" + else: + pad_mode = "'same'" + return {"pad_mode": pad_mode} + +tensor_dot_view_gen_explicit_map = lambda params_pt, args_pt: {"shape": "(" + args_pt["*shape"] + ",)"} +tensor_dot_reshape_gen_explicit_map = lambda params_pt, args_pt: {"shape": "(" + args_pt["*shape"] + ",)"} +nn_conv2d_gen_explicit_map = lambda params_pt, args_pt: {"pad_mode": "'pad'"} +nn_batchnorm2d_gen_explicit_map = partial(gen_explicit_map_one_delta, k_ms="momentum", k_pt="momentum") +nn_dropout_gen_explicit_map = partial(gen_explicit_map_one_delta, k_ms="keep_prob", k_pt="p") diff --git a/mindinsight/mindconverter/mappings/f_mappings.json b/mindinsight/mindconverter/mappings/f_mappings.json new file mode 100644 index 0000000000000000000000000000000000000000..a342721cfad0f7dce1924a4433f0345ceb9fe2cb --- /dev/null +++ b/mindinsight/mindconverter/mappings/f_mappings.json @@ -0,0 +1,108 @@ +{ + "F.avg_pool2d": { + "ms_api": [ + "P.AvgPool", + { + "ksize": 1, + "strides": 1, + "padding": "valid", + "input": "REQUIRED" + }, + [ + "ksize", + "strides", + "padding" + ] + ], + "pt_api": [ + "F.avg_pool2d", + { + "input": "REQUIRED", + "kernel_size": "REQUIRED", + "stride": null, + "padding": 0, + "dilation": 1, + "ceil_mode": false, + "return_indices": false + } + ], + "ms2pt_mapping": { + "ksize": "kernel_size", + "strides": "stride", + "input": "input" + }, + "gen_explicit_map": "gen_explicit_map_f_max_pool2d" + }, + "F.max_pool2d": { + "ms_api": [ + "P.MaxPool", + { + "ksize": 1, + "strides": 1, + "padding": "valid", + "input": "REQUIRED" + }, + [ + "ksize", + "strides", + "padding" + ] + ], + "pt_api": [ + "F.max_pool2d", + { + "input": "REQUIRED", + "kernel_size": "REQUIRED", + "stride": null, + "padding": 0, + "dilation": 1, + "ceil_mode": false, + "return_indices": false + } + ], + "ms2pt_mapping": { + "ksize": "kernel_size", + "strides": "stride", + "input": "input" + }, + "gen_explicit_map": "gen_explicit_map_f_max_pool2d" + }, + "F.relu": { + "ms_api": [ + "P.ReLU", + { + "input": "REQUIRED" + } + ], + "pt_api": [ + "F.relu", + { + "input": "REQUIRED", + "inplace": false + } + ], + "ms2pt_mapping": { + "input": "input" + }, + "gen_explicit_map": null + }, + "F.relu6": { + "ms_api": [ + "P.ReLU6", + { + "input": "REQUIRED" + } + ], + "pt_api": [ + "F.relu6", + { + "input": "REQUIRED", + "inplace": false + } + ], + "ms2pt_mapping": { + "input": "input" + }, + "gen_explicit_map": null + } +} \ No newline at end of file diff --git a/mindinsight/mindconverter/mappings/nn_mappings.json b/mindinsight/mindconverter/mappings/nn_mappings.json new file mode 100644 index 0000000000000000000000000000000000000000..1f50bb076f646540e0d5876ea70d4b313c914e01 --- /dev/null +++ b/mindinsight/mindconverter/mappings/nn_mappings.json @@ -0,0 +1,220 @@ +{ + "nn.Dropout": { + "ms_api": [ + "nn.Dropout", + { + "keep_prob": 0.5, + "seed0": 0, + "seed1": 0, + "dtype": "mstype.float32" + } + ], + "pt_api": [ + "nn.Dropout", + { + "p": 0.5, + "inplace": false + } + ], + "ms2pt_mapping": { + "keep_prob": "p" + }, + "gen_explicit_map": "nn_dropout_gen_explicit_map" + }, + "nn.AvgPool2d": { + "ms_api": [ + "nn.AvgPool2d", + { + "kernel_size": 1, + "stride": 1, + "pad_mode": "valid" + } + ], + "pt_api": [ + "nn.AvgPool2d", + { + "kernel_size": "REQUIRED", + "stride": null, + "padding": 0, + "dilation": 1, + "return_indices": false, + "ceil_mode": "False" + } + ], + "ms2pt_mapping": { + "kernel_size": "kernel_size", + "stride": "stride" + }, + "gen_explicit_map": "gen_explicit_map_nn_maxpool2d" + }, + "nn.MaxPool2d": { + "ms_api": [ + "nn.MaxPool2d", + { + "kernel_size": 1, + "stride": 1, + "pad_mode": "valid" + } + ], + "pt_api": [ + "nn.MaxPool2d", + { + "kernel_size": "REQUIRED", + "stride": null, + "padding": 0, + "dilation": 1, + "return_indices": false, + "ceil_mode": "False" + } + ], + "ms2pt_mapping": { + "kernel_size": "kernel_size", + "stride": "stride" + }, + "gen_explicit_map": "gen_explicit_map_nn_maxpool2d" + }, + "nn.Linear": { + "ms_api": [ + "nn.Dense", + { + "in_channels": "REQUIRED", + "out_channels": "REQUIRED", + "weight_init": "normal", + "bias_init": "zeros", + "has_bias": true, + "activation": null + } + ], + "pt_api": [ + "nn.Linear", + { + "in_features": "REQUIRED", + "out_features": "REQUIRED", + "bias": true + } + ], + "ms2pt_mapping": { + "in_channels": "in_features", + "out_channels": "out_features", + "has_bias": "bias" + } + }, + "nn.ReLU6": { + "ms_api": [ + "nn.ReLU6", + {} + ], + "pt_api": [ + "nn.ReLU6", + { + "inplace": false + } + ], + "ms2pt_mapping": {} + }, + "nn.ReLU": { + "ms_api": [ + "nn.ReLU", + {} + ], + "pt_api": [ + "F.relu", + { + "inplace": false + } + ], + "ms2pt_mapping": {} + }, + "nn.BatchNorm2d": { + "ms_api": [ + "nn.BatchNorm2d", + { + "num_features": "REQUIRED", + "eps": 1e-05, + "momentum": 0.9, + "affine": true, + "gamma_init": "ones", + "beta_init": "zeros", + "moving_mean_init": "zeros", + "moving_var_init": "ones", + "use_batch_statistics": true + } + ], + "pt_api": [ + "nn.BatchNorm2d", + { + "num_features": "REQUIRED", + "eps": 1e-05, + "momentum": 0.1, + "affine": true, + "track_running_stats": true + } + ], + "ms2pt_mapping": { + "num_features": "num_features", + "eps": "eps", + "affine": "affine", + "use_batch_statistics": "track_running_stats" + }, + "gen_explicit_map": "nn_batchnorm2d_gen_explicit_map" + }, + "nn.Conv2d": { + "ms_api": [ + "nn.Conv2d", + { + "in_channels": "REQUIRED", + "out_channels": "REQUIRED", + "kernel_size": "REQUIRED", + "stride": 1, + "pad_mode": "same", + "padding": 0, + "dilation": 1, + "group": 1, + "has_bias": false, + "weight_init": "normal", + "bias_init": "zeros" + } + ], + "pt_api": [ + "nn.Conv2d", + { + "in_channels": "REQUIRED", + "out_channels": "REQUIRED", + "kernel_size": "REQUIRED", + "stride": 1, + "padding": 0, + "dilation": 1, + "groups": 1, + "bias": true, + "padding_mode": "zeros" + } + ], + "ms2pt_mapping": { + "in_channels": "in_channels", + "out_channels": "out_channels", + "kernel_size": "kernel_size", + "stride": "stride", + "padding": "padding", + "dilation": "dilation", + "group": "groups", + "has_bias": "bias" + }, + "gen_explicit_map": "nn_conv2d_gen_explicit_map" + }, + "nn.Sequential": { + "ms_api": [ + "nn.SequentialCell", + { + "*args": " REQUIRED" + } + ], + "pt_api": [ + "nn.Sequential", + { + "*args": " REQUIRED" + } + ], + "export_key": false, + "gen_explicit_map": "gen_explicit_map_nn_sequential" + } +} \ No newline at end of file diff --git a/mindinsight/mindconverter/mappings/tensor_dot_mappings.json b/mindinsight/mindconverter/mappings/tensor_dot_mappings.json new file mode 100644 index 0000000000000000000000000000000000000000..51d6475a9c94032c60b29c89a0b298f254f9092f --- /dev/null +++ b/mindinsight/mindconverter/mappings/tensor_dot_mappings.json @@ -0,0 +1,119 @@ +{ + ".view": { + "ms_api": [ + "P.Reshape", + { + "x": "REQUIRED", + "shape": "REQUIRED" + } + ], + "pt_api": [ + ".view", + { + "*shape": "REQUIRED" + } + ], + "ms2pt_mapping": { + "x": "call_name" + }, + "gen_explicit_map": "tensor_dot_view_gen_explicit_map" + }, + ".size": { + "ms_api": [ + "P.Shape", + { + "x": "REQUIRED" + } + ], + "pt_api": [ + ".size", + { + "idx": "REQUIRED" + } + ], + "ms2pt_mapping": { + "x": "call_name" + } + }, + ".flatten": { + "ms_api": [ + "P.Flatten", + { + "input": "REQUIRED" + } + ], + "pt_api": [ + ".flatten", + { + "start_dim": 0, + "end_dim": -1 + } + ], + "ms2pt_mapping": { + "input": "call_name" + } + }, + ".reshape": { + "ms_api": [ + "P.Reshape", + { + "x": "REQUIRED", + "shape": "REQUIRED" + } + ], + "pt_api": [ + ".reshape", + { + "*shape": "REQUIRED" + } + ], + "ms2pt_mapping": { + "x": "call_name" + }, + "gen_explicit_map": "tensor_dot_reshape_gen_explicit_map" + }, + ".mean": { + "ms_api": [ + "P.ReduceMean", + { + "keep_dims": false, + "input": "REQUIRED", + "axis": [] + } + ], + "pt_api": [ + ".mean", + { + "dim": null, + "keepdim": false + } + ], + "ms2pt_mapping": { + "keep_dims": "keepdim", + "axis": "dim", + "input": "call_name" + } + }, + ".squeeze": { + "ms_api": [ + "P.ReduceMean", + { + "input": "REQUIRED", + "axis": [] + }, + [ + "axis" + ] + ], + "pt_api": [ + ".squeeze", + { + "dim": null + } + ], + "ms2pt_mapping": { + "axis": "dim", + "input": "call_name" + } + } +} \ No newline at end of file diff --git a/mindinsight/mindconverter/mappings/torch_dot_mappings.json b/mindinsight/mindconverter/mappings/torch_dot_mappings.json new file mode 100644 index 0000000000000000000000000000000000000000..8482efcad4bcabd161d7005307be1af9a50c5c8a --- /dev/null +++ b/mindinsight/mindconverter/mappings/torch_dot_mappings.json @@ -0,0 +1,45 @@ +{ + "torch.flatten": { + "ms_api": [ + "P.Flatten", + { + "input": "REQUIRED" + } + ], + "pt_api": [ + "torch.flatten", + { + "input": "REQUIRED", + "start_dim": 0, + "end_dim": -1 + } + ], + "ms2pt_mapping": { + "input": "input" + } + }, + "torch.cat": { + "ms_api": [ + "P.Concat", + { + "axis": 0, + "input": "REQUIRED" + }, + [ + "axis" + ] + ], + "pt_api": [ + "torch.cat", + { + "tensors": "REQUIRED", + "dim": 0, + "out": null + } + ], + "ms2pt_mapping": { + "input": "tensors", + "axis": "dim" + } + } +} \ No newline at end of file diff --git a/mindinsight/mindconverter/f_list.json b/mindinsight/mindconverter/ops/f_list.json similarity index 100% rename from mindinsight/mindconverter/f_list.json rename to mindinsight/mindconverter/ops/f_list.json diff --git a/mindinsight/mindconverter/nn_list.json b/mindinsight/mindconverter/ops/nn_list.json similarity index 100% rename from mindinsight/mindconverter/nn_list.json rename to mindinsight/mindconverter/ops/nn_list.json diff --git a/mindinsight/mindconverter/tensor_dot_list.json b/mindinsight/mindconverter/ops/tensor_dot_list.json similarity index 100% rename from mindinsight/mindconverter/tensor_dot_list.json rename to mindinsight/mindconverter/ops/tensor_dot_list.json diff --git a/mindinsight/mindconverter/torch_dot_list.json b/mindinsight/mindconverter/ops/torch_dot_list.json similarity index 100% rename from mindinsight/mindconverter/torch_dot_list.json rename to mindinsight/mindconverter/ops/torch_dot_list.json