From 8c0104f42057881ee6932fc3958fe92223a1fea2 Mon Sep 17 00:00:00 2001 From: quyongxiu Date: Sat, 6 Jun 2020 15:17:02 +0800 Subject: [PATCH] add a new op mapping add op mappings fix pylint fix mistake { as [ add tests add test for eye and revise func and mappings revise mapping and add test fix pylint fix pylint --- mindinsight/mindconverter/config.py | 6 +- mindinsight/mindconverter/funcs.py | 25 +- .../mindconverter/mappings/f_mappings.json | 45 +++ .../mindconverter/mappings/nn_mappings.json | 113 +++++++ .../mappings/tensor_dot_mappings.json | 38 +++ .../mappings/torch_dot_mappings.json | 201 +++++++++++ tests/ut/mindconverter/test_converter.py | 317 ++++++++++++++++++ 7 files changed, 742 insertions(+), 3 deletions(-) diff --git a/mindinsight/mindconverter/config.py b/mindinsight/mindconverter/config.py index 34de567..9f07853 100644 --- a/mindinsight/mindconverter/config.py +++ b/mindinsight/mindconverter/config.py @@ -337,7 +337,7 @@ F_MAPPING_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'mappi F_MAPPING = get_mapping_from_file(F_MAPPING_PATH) # update to add key starts with 'nn.functional.' NN_FUNCTIONAL_D = {"nn.functional." + k[len('F.'):]: v for k, v in F_MAPPING.items()} -# update to add key starts with 'torch.nn.functiona.l' +# update to add key starts with 'torch.nn.functional.' TORCH_NN_FUNCTIONAL_D = {"torch.nn.functional." + k[len('F.'):]: v for k, v in F_MAPPING.items()} F_MAPPING.update(NN_FUNCTIONAL_D) F_MAPPING.update(TORCH_NN_FUNCTIONAL_D) @@ -392,5 +392,7 @@ ALL_UNSUPPORTED = NN_UNSUPPORTED + F_UNSUPPORTED + TORCH_DOT_UNSUPPORTED + TENSO UNSUPPORTED_WARN_INFOS = { "nn.AdaptiveAvgPool2d": "maybe could convert to P.ReduceMean", "F.adaptive_avg_pool2d": "maybe could convert to P.ReduceMean", - "F.dropout": "please use nn.Dropout in __init__()" + "F.dropout": "please use nn.Dropout in __init__()", + "torch.max": "try to use P.ArgMaxWithValue, notice that two values are returned by P.ArgMaxWithValue", + "torch.min": "try to use P.ArgMinWithValue, notice that two values are returned by P.ArgMinWithValue" } diff --git a/mindinsight/mindconverter/funcs.py b/mindinsight/mindconverter/funcs.py index eb2b57c..e5e5081 100644 --- a/mindinsight/mindconverter/funcs.py +++ b/mindinsight/mindconverter/funcs.py @@ -99,8 +99,31 @@ def gen_explicit_map_nn_maxpool2d(params_pt, args_pt): pad_mode = "'same'" return {"pad_mode": pad_mode} -tensor_dot_view_gen_explicit_map = lambda params_pt, args_pt: {"shape": "(" + args_pt["*shape"] + ",)"} + +def torch_dot_eye_gen_explicit_map(_, args_pt): + """ + Generate explicit_map for torch.eye. + + Args: + args_pt (dict): Args for APIPt. + + Returns: + dict, map between frames. + """ + explicit_map = {'t': 'mindspore.int32'} + if args_pt.get('m'): + explicit_map.update({'m': args_pt.get('m')}) + else: + explicit_map.update({'m': args_pt.get('n')}) + return explicit_map + +tensor_dot_permute_gen_explicit_map = lambda params_pt, args_pt: {"input_perm": "(" + args_pt["*dIms"] + ",)"} +tensor_dot_repeat_gen_explicit_map = lambda params_pt, args_pt: {"multiples": "(" + args_pt["*sizes"] + ",)"} tensor_dot_reshape_gen_explicit_map = lambda params_pt, args_pt: {"shape": "(" + args_pt["*shape"] + ",)"} +tensor_dot_view_gen_explicit_map = lambda params_pt, args_pt: {"shape": "(" + args_pt["*shape"] + ",)"} nn_conv2d_gen_explicit_map = lambda params_pt, args_pt: {"pad_mode": "'pad'"} nn_batchnorm2d_gen_explicit_map = partial(gen_explicit_map_one_delta, k_ms="momentum", k_pt="momentum") +nn_batchnorm1d_gen_explicit_map = nn_batchnorm2d_gen_explicit_map nn_dropout_gen_explicit_map = partial(gen_explicit_map_one_delta, k_ms="keep_prob", k_pt="p") +torch_dot_add_gen_explicit_map = lambda params_pt, args_pt:\ + {"input_y": (args_pt['value'] + '*' + args_pt["alpha"]) if args_pt.get("alpha") else args_pt['value']} diff --git a/mindinsight/mindconverter/mappings/f_mappings.json b/mindinsight/mindconverter/mappings/f_mappings.json index a342721..9430217 100644 --- a/mindinsight/mindconverter/mappings/f_mappings.json +++ b/mindinsight/mindconverter/mappings/f_mappings.json @@ -104,5 +104,50 @@ "input": "input" }, "gen_explicit_map": null + }, + "F.normalize": { + "ms_api": [ + "P.L2Normalize", + { + "axis": 0, + "epsilon": 0.0001, + "input_x": "REQUIRED" + }, + [ + "axis", + "epsilon" + ] + ], + "pt_api": [ + "F.normalize", + { + "input": "REQUIRED", + "p": 2, + "dim": 1, + "eps": 1e-12 + } + ], + "ms2pt_mapping": { + "input_x": "input", + "epsilon": "eps", + "axis": "dim" + } + }, + "F.sigmoid": { + "ms_api": [ + "P.Sigmoid", + { + "input_x": "REQUIRED" + } + ], + "pt_api": [ + "F.sigmoid", + { + "input": "REQUIRED" + } + ], + "ms2pt_mapping": { + "input_x": "input" + } } } \ No newline at end of file diff --git a/mindinsight/mindconverter/mappings/nn_mappings.json b/mindinsight/mindconverter/mappings/nn_mappings.json index 1f50bb0..b3d6714 100644 --- a/mindinsight/mindconverter/mappings/nn_mappings.json +++ b/mindinsight/mindconverter/mappings/nn_mappings.json @@ -216,5 +216,118 @@ ], "export_key": false, "gen_explicit_map": "gen_explicit_map_nn_sequential" + }, + "nn.BatchNorm1d": { + "ms_api": [ + "nn.BatchNorm1d", + { + "num_features": "REQUIRED", + "eps": 1e-05, + "momentum": 0.9, + "affine": true, + "gamma_init": "ones", + "beta_init": "zeros", + "moving_mean_init": "zeros", + "moving_var_init": "ones", + "use_batch_statistics": true + } + ], + "pt_api": [ + "nn.BatchNorm1d", + { + "num_features": "REQUIRED", + "eps": 1e-05, + "momentum": 0.1, + "affine": true, + "track_running_stats": true + } + ], + "ms2pt_mapping": { + "num_features": "num_features", + "eps": "eps", + "affine": "affine", + "use_batch_statistics": "track_running_stats" + }, + "gen_explicit_map": "nn_batchnorm1d_gen_explicit_map" + }, + "nn.LayerNorm": { + "ms_api": [ + "nn.LayerNorm", + { + "normalized_shape": "REQUIRED", + "begin_norm_axis": -1, + "begin_params_axis": -1, + "gamma_init": "ones", + "beta_init": "zeros", + "epsilon": 1e-07 + } + ], + "pt_api": [ + "nn.LayerNorm", + { + "normalized_shape": "REQUIRED", + "eps": 1e-05, + "elementwise_affine": true + } + ], + "ms2pt_mapping": { + "normalized_shape": "normalized_shape", + "epsilon": "eps" + } + }, + "nn.LeakyReLU": { + "ms_api": [ + "nn.LeakyReLU", + { + "alpha": 0.2 + } + ], + "pt_api": [ + "nn.LeakyReLU", + { + "negative_slope": 0.2, + "inplace": false + } + ], + "ms2pt_mapping": { + "alpha": "negative_slope" + } + }, + "nn.PReLU": { + "ms_api": [ + "nn.PReLU", + { + "channel": 1, + "w": 0.25 + } + ], + "pt_api": [ + "nn.PReLU", + { + "num_parameters": 1, + "init": 0.25 + } + ], + "ms2pt_mapping": { + "channel": "num_parameters", + "w": "init" + } + }, + "nn.Softmax": { + "ms_api": [ + "nn.Softmax", + { + "axis": -1 + } + ], + "pt_api": [ + "nn.Softmax", + { + "dim": "REQUIRED" + } + ], + "ms2pt_mapping": { + "axis": "dim" + } } } \ No newline at end of file diff --git a/mindinsight/mindconverter/mappings/tensor_dot_mappings.json b/mindinsight/mindconverter/mappings/tensor_dot_mappings.json index 51d6475..c96144d 100644 --- a/mindinsight/mindconverter/mappings/tensor_dot_mappings.json +++ b/mindinsight/mindconverter/mappings/tensor_dot_mappings.json @@ -115,5 +115,43 @@ "axis": "dim", "input": "call_name" } + }, + ".repeat": { + "ms_api": [ + "P.Tile", + { + "input_x": "REQUIRED", + "multiples": "REQUIRED" + } + ], + "pt_api": [ + ".repeat", + { + "*sizes": "REQUIRED" + } + ], + "ms2pt_mapping": { + "input_x": "call_name" + }, + "gen_explicit_map": "tensor_dot_repeat_gen_explicit_map" + }, + ".permute": { + "ms_api": [ + "P.Transpose", + { + "input_x": "REQUIRED", + "input_perm": "REQUIRED" + } + ], + "pt_api": [ + ".permute", + { + "*dIms": "REQUIRED" + } + ], + "ms2pt_mapping": { + "input_x": "call_name" + }, + "gen_explicit_map": "tensor_dot_permute_gen_explicit_map" } } \ No newline at end of file diff --git a/mindinsight/mindconverter/mappings/torch_dot_mappings.json b/mindinsight/mindconverter/mappings/torch_dot_mappings.json index 8482efc..ee99b10 100644 --- a/mindinsight/mindconverter/mappings/torch_dot_mappings.json +++ b/mindinsight/mindconverter/mappings/torch_dot_mappings.json @@ -41,5 +41,206 @@ "input": "tensors", "axis": "dim" } + }, + "torch.abs": { + "ms_api": [ + "P.Abs", + { + "input_x": "REQUIRED" + } + ], + "pt_api": [ + ".abs", + { + "input": "REQUIRED" + } + ], + "ms2pt_mapping": { + "input_x": "input" + } + }, + "torch.acos": { + "ms_api": [ + "P.ACos", + { + "input_x": "REQUIRED" + } + ], + "pt_api": [ + ".acos", + { + "input": "REQUIRED" + } + ], + "ms2pt_mapping": { + "input_x": "input" + } + }, + "torch.cos": { + "ms_api": [ + "P.Cos", + { + "input_x": "REQUIRED" + } + ], + "pt_api": [ + ".cos", + { + "input": "REQUIRED" + } + ], + "ms2pt_mapping": { + "input_x": "input" + } + }, + "torch.exp": { + "ms_api": [ + "P.Exp", + { + "input_x": "REQUIRED" + } + ], + "pt_api": [ + ".exp", + { + "input": "REQUIRED" + } + ], + "ms2pt_mapping": { + "input_x": "input" + } + }, + "torch.log": { + "ms_api": [ + "P.Log", + { + "input_x": "REQUIRED" + } + ], + "pt_api": [ + ".log", + { + "input": "REQUIRED" + } + ], + "ms2pt_mapping": { + "input_x": "input" + } + }, + "torch.pow": { + "ms_api": [ + "P.Pow", + { + "input_x": "REQUIRED", + "input_y": "REQUIRED" + } + ], + "pt_api": [ + ".pow", + { + "input": "REQUIRED", + "exponent": "REQUIRED" + } + ], + "ms2pt_mapping": { + "input_x": "input", + "input_y": "exponent" + } + }, + "torch.div": { + "ms_api": [ + "P.Div", + { + "input_x": "REQUIRED", + "input_y": "REQUIRED" + } + ], + "pt_api": [ + ".div", + { + "input": "REQUIRED", + "other": "REQUIRED" + } + ], + "ms2pt_mapping": { + "input_x": "input", + "input_y": "other" + } + }, + "torch.sin": { + "ms_api": [ + "P.Sin", + { + "input_x": "REQUIRED" + } + ], + "pt_api": [ + ".sin", + { + "input": "REQUIRED" + } + ], + "ms2pt_mapping": { + "input_x": "input" + } + }, + "torch.sqrt": { + "ms_api": [ + "P.Sqrt", + { + "input_x": "REQUIRED" + } + ], + "pt_api": [ + ".sqrt", + { + "input": "REQUIRED" + } + ], + "ms2pt_mapping": { + "input_x": "input" + } + }, + "torch.add": { + "ms_api": [ + "P.TensorAdd", + { + "input_x": "REQUIRED", + "input_y": "REQUIRED" + } + ], + "pt_api": [ + ".add", + { + "input": "REQUIRED", + "value": "REQUIRED", + "alpha": 1 + } + ], + "ms2pt_mapping": { + "input_x": "input" + }, + "gen_explicit_map": "torch_dot_add_gen_explicit_map" + }, + "torch.eye": { + "ms_api": [ + "P.Eye", + { + "n": "REQUIRED", + "m": "REQUIRED", + "t": "REQUIRED" + } + ], + "pt_api": [ + ".eye", + { + "n": "REQUIRED", + "m": "REQUIRED" + } + ], + "ms2pt_mapping": { + "n": "n" + }, + "gen_explicit_map": "torch_dot_eye_gen_explicit_map" } } \ No newline at end of file diff --git a/tests/ut/mindconverter/test_converter.py b/tests/ut/mindconverter/test_converter.py index 9674cfc..44ad641 100644 --- a/tests/ut/mindconverter/test_converter.py +++ b/tests/ut/mindconverter/test_converter.py @@ -14,6 +14,7 @@ # ============================================================================ """Test Converter""" from mindinsight.mindconverter.converter import Converter +from mindinsight.mindconverter.config import NN_MAPPING class TestConverter: @@ -82,3 +83,319 @@ class TestConverter: result = self.converter_ins.find_right_parentheses(code, left_index) assert_index = len(code) - 1 assert result == assert_index + + # test convert_api with nn ops + def test_convert_api_nn_layernorm(self): + """Test convert_api function work ok when convert api nn.LayerNorm""" + code = """ + def __init__(self, num_classes=1000): + self.features = nn.SequentialCell([ + nn.LayerNorm((5, 10, 10), elementwise_affine=False), + nn.ReLU(inplace=False) + ]) + """ + api_name = 'nn.LayerNorm' + start = code.find(api_name) + + layer_norm_info = NN_MAPPING.get(api_name) + expected_ms_api_name = 'nn.LayerNorm' + + epsilon = layer_norm_info.pt_api.params.get('eps') + + replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) + assert replaced_code == code.replace('nn.LayerNorm((5, 10, 10), elementwise_affine=False)', + '{}(normalized_shape=(5, 10, 10), epsilon={})'.format( + expected_ms_api_name, epsilon)) + assert new_start == start + len(expected_ms_api_name) + + def test_convert_api_nn_leaky_relu(self): + """Test convert_api function work ok when convert api nn.LeakyReLU""" + code = """ + def __init__(self, num_classes=1000): + self.features = nn.SequentialCell([ + nn.LayerNorm((5, 10, 10), elementwise_affine=False), + nn.LeakyReLU(0.3)]) + """ + api_name = 'nn.LeakyReLU' + start = code.find(api_name) + expected_ms_api_name = 'nn.LeakyReLU' + + replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) + assert replaced_code == code.replace('nn.LeakyReLU(0.3)', + '{}(alpha=0.3)'.format(expected_ms_api_name)) + assert new_start == start + len(expected_ms_api_name) + + def test_convert_api_nn_prelu(self): + """Test convert_api function work ok when convert api nn.PReLU""" + code = """ + input = torch.randn(2, 3, 5) + nn.PReLU()(input) + + """ + api_name = 'nn.PReLU' + start = code.find(api_name) + expected_ms_api_name = 'nn.PReLU' + replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) + assert replaced_code == code.replace('nn.PReLU()(input)', + '{}()(input)'.format(expected_ms_api_name)) + assert new_start == start + len(expected_ms_api_name) + + def test_convert_api_nn_softmax(self): + """Test convert_api function work ok when convert api nn.Softmax""" + code = """ + nn.Softmax(dim=1)(input) + """ + api_name = 'nn.Softmax' + expected_ms_api_name = 'nn.Softmax' + start = code.find(api_name) + + replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) + assert replaced_code == code.replace('nn.Softmax(dim=1)(input)', + '{}(axis=1)(input)'.format(expected_ms_api_name)) + assert new_start == start + len(expected_ms_api_name) + + # test convert_api with torch dot ops + def test_convert_api_torch_dot_abs(self): + """Test convert_api function work ok when convert api torch.abs""" + code = """ + torch.abs(input) + """ + api_name = 'torch.abs' + start = code.find(api_name) + expected_ms_api_name = 'P.Abs' + + replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) + assert replaced_code == code.replace('torch.abs(input)', + '{}()(input)'.format(expected_ms_api_name)) + assert new_start == start + len(expected_ms_api_name) + + def test_convert_api_torch_dot_acos(self): + """Test convert_api function work ok when convert api torch.acos""" + code = """ + torch.acos(input) + """ + api_name = 'torch.acos' + start = code.find(api_name) + expected_ms_api_name = 'P.ACos' + + replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) + assert replaced_code == code.replace('torch.acos(input)', + '{}()(input)'.format(expected_ms_api_name)) + assert new_start == start + len(expected_ms_api_name) + + def test_convert_api_torch_dot_cos(self): + """Test convert_api function work ok when convert api torch.cos""" + code = """ + torch.cos(input) + """ + api_name = 'torch.cos' + expected_ms_api_name = 'P.Cos' + start = code.find(api_name) + + replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) + assert replaced_code == code.replace('torch.cos(input)', + '{}()(input)'.format(expected_ms_api_name)) + assert new_start == start + len(expected_ms_api_name) + + def test_convert_api_torch_dot_exp(self): + """Test convert_api function work ok when convert api torch.exp""" + code = """ + torch.exp(input) + """ + api_name = 'torch.exp' + expected_ms_api_name = 'P.Exp' + start = code.find(api_name) + + replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) + assert replaced_code == code.replace('torch.exp(input)', + '{}()(input)'.format(expected_ms_api_name)) + assert new_start == start + len(expected_ms_api_name) + + def test_convert_api_torch_dot_log(self): + """Test convert_api function work ok when convert api torch.log""" + code = """ + torch.log(input) + """ + api_name = 'torch.log' + expected_ms_api_name = 'P.Log' + start = code.find(api_name) + + replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) + assert replaced_code == code.replace('torch.log(input)', + '{}()(input)'.format(expected_ms_api_name)) + assert new_start == start + len(expected_ms_api_name) + + def test_convert_api_torch_dot_pow(self): + """Test convert_api function work ok when convert api torch.pow""" + code = """ + torch.pow(a, exp) + """ + api_name = 'torch.pow' + expected_ms_api_name = 'P.Pow' + start = code.find(api_name) + + replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) + assert replaced_code == code.replace('torch.pow(a, exp)', + '{}()(a, exp)'.format(expected_ms_api_name)) + assert new_start == start + len(expected_ms_api_name) + + def test_convert_api_torch_dot_div(self): + """Test convert_api function work ok when convert api torch.div""" + code = """ + input = torch.randn(5) + other = torch.randn(5) + torch.div(input, other) + """ + api_name = 'torch.div' + expected_ms_api_name = 'P.Div' + start = code.find(api_name) + + replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) + + assert replaced_code == code.replace('torch.div(input, other)', + '{}()(input, other)'.format(expected_ms_api_name)) + assert new_start == start + len(expected_ms_api_name) + + def test_convert_api_torch_dot_sin(self): + """Test convert_api function work ok when convert api torch.sin""" + code = """ + torch.sin(input) + """ + api_name = 'torch.sin' + expected_ms_api_name = 'P.Sin' + start = code.find(api_name) + + replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) + assert replaced_code == code.replace('torch.sin(input)', + '{}()(input)'.format(expected_ms_api_name)) + assert new_start == start + len(expected_ms_api_name) + + def test_convert_api_torch_dot_sqrt(self): + """Test convert_api function work ok when convert api torch.sqrt""" + code = """ + torch.sqrt(input) + """ + api_name = 'torch.sqrt' + expected_ms_api_name = 'P.Sqrt' + start = code.find(api_name) + + replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) + assert replaced_code == code.replace('torch.sqrt(input)', + '{}()(input)'.format(expected_ms_api_name)) + assert new_start == start + len(expected_ms_api_name) + + def test_convert_api_torch_dot_eye_with_n(self): + """Test convert_api function work ok when convert api torch.eye""" + code = """ + torch.eye(3) + """ + api_name = 'torch.eye' + expected_ms_api_name = 'P.Eye' + start = code.find(api_name) + + replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) + assert replaced_code == code.replace('torch.eye(3)', + '{}()(3, 3, mindspore.int32)'.format(expected_ms_api_name)) + assert new_start == start + len(expected_ms_api_name) + + def test_convert_api_torch_dot_eye_with_m(self): + """Test convert_api function work ok when convert api torch.eye""" + code = """ + torch.eye(3, 4) + """ + api_name = 'torch.eye' + expected_ms_api_name = 'P.Eye' + start = code.find(api_name) + + replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) + assert replaced_code == code.replace('torch.eye(3, 4)', + '{}()(3, 4, mindspore.int32)'.format(expected_ms_api_name)) + assert new_start == start + len(expected_ms_api_name) + + def test_convert_api_torch_dot_add_with_alpha_default(self): + """Test convert_api function work ok when convert api torch.add""" + code = """ + torch.add(input, value) + """ + api_name = 'torch.add' + expected_ms_api_name = 'P.TensorAdd' + start = code.find(api_name) + + replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) + assert replaced_code == code.replace('torch.add(input, value)', + '{}()(input, value)'.format(expected_ms_api_name)) + assert new_start == start + len(expected_ms_api_name) + + def test_convert_api_torch_dot_add_with_alpha_not_default(self): + """Test convert_api function work ok when convert api torch.add""" + code = """ + torch.add(input, value, 3) + """ + api_name = 'torch.add' + expected_ms_api_name = 'P.TensorAdd' + start = code.find(api_name) + + replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) + assert replaced_code == code.replace('torch.add(input, value, 3)', + '{}()(input, value*3)'.format(expected_ms_api_name)) + assert new_start == start + len(expected_ms_api_name) + + # test convert_api with F ops + def test_convert_api_f_normalize(self): + """Test convert_api function work ok when convert api F.normalize""" + code = """ + input = torch.randn(2, 3, 5) + F.normalize(input) + """ + api_name = 'F.normalize' + start = code.find(api_name) + expected_ms_api_name = 'P.L2Normalize' + + replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) + assert replaced_code == code.replace('F.normalize(input)', + '{}(1, 1e-12)(input)'.format(expected_ms_api_name)) + assert new_start == start + len(expected_ms_api_name) + + def test_convert_api_f_sigmoid(self): + """Test convert_api function work ok when convert api F.sigmoid""" + code = """ + input = torch.randn(2, 3, 5) + F.sigmoid(input) + """ + api_name = 'F.sigmoid' + start = code.find(api_name) + expected_ms_api_name = 'P.Sigmoid' + + replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) + assert replaced_code == code.replace('F.sigmoid(input)', + '{}()(input)'.format(expected_ms_api_name)) + assert new_start == start + len(expected_ms_api_name) + + # test convert_api with tensor dot ops + def test_convert_api_tensor_dot_repeat(self): + """Test convert_api function work ok when convert api .repeat""" + code = """ + x.repeat(4, 2) + """ + api_name = '.repeat' + start = code.find(api_name) + expected_ms_api_name = 'P.Tile' + + replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) + assert replaced_code == code.replace('x.repeat(4, 2)', + '{}()(x, {})'.format(expected_ms_api_name, '(4, 2,)')) + assert new_start == start + len(expected_ms_api_name) + + def test_convert_api_tensor_dot_permute(self): + """Test convert_api function work ok when convert api .permute""" + code = """ + x.permute(2, 0, 1) + """ + api_name = '.permute' + start = code.find(api_name) + expected_ms_api_name = 'P.Transpose' + + replaced_code, new_start = self.converter_ins.convert_api(code, start, api_name) + assert replaced_code == code.replace('x.permute(2, 0, 1)', + '{}()(x, (2, 0, 1,))'.format(expected_ms_api_name)) + assert new_start == start + len(expected_ms_api_name) -- GitLab