diff --git a/mindinsight/mindconverter/config.py b/mindinsight/mindconverter/config.py index b6843efb0f768d02ab1953a535a98e45e0db7594..840e4cfde0c1d1b40f823559e6d8bd26acd80dff 100644 --- a/mindinsight/mindconverter/config.py +++ b/mindinsight/mindconverter/config.py @@ -71,9 +71,9 @@ class APIPt: or the given args_str not valid. """ # expr is REQUIRED to meet (**) format - if not (len(args_str) >= 2 and args_str[0] == "(" and args_str[-1] == ")"): - raise ValueError('[{}] is think as args str, it should start with "(" and end with ")"'.format(args_str)) - + if not (len(args_str) >= 2 and args_str[0] == "(" and args_str.strip()[-1] == ")"): + raise ValueError('"{}" is think as args string, it should start with "(" and end with ")" without ' + 'considering spaces'.format(args_str)) try: ast_node = ast.parse("whatever_call_name" + args_str) call_node = ast_node.body[0].value diff --git a/mindinsight/mindconverter/funcs.py b/mindinsight/mindconverter/funcs.py index e5e5081f3119a9031c4e3dc186f435010da4c8d3..ffc8cd7fa8ea62d224fa2a5d2a6a2b375f9a7e1d 100644 --- a/mindinsight/mindconverter/funcs.py +++ b/mindinsight/mindconverter/funcs.py @@ -35,7 +35,14 @@ def gen_explicit_map_f_max_pool2d(params_pt, args_pt): padding = "'valid'" else: padding = "'same'" - return {"padding": padding} + + if 'stride' in args_pt: + strides = args_pt['stride'] + else: + strides = args_pt['kernel_size'] + + return {"padding": padding, + "strides": strides} def gen_explicit_map_nn_sequential(_, args_pt): @@ -97,7 +104,14 @@ def gen_explicit_map_nn_maxpool2d(params_pt, args_pt): pad_mode = "'valid'" else: pad_mode = "'same'" - return {"pad_mode": pad_mode} + + if 'stride' in args_pt: + stride = args_pt['stride'] + else: + stride = args_pt['kernel_size'] + + return {"pad_mode": pad_mode, + "stride": stride} def torch_dot_eye_gen_explicit_map(_, args_pt): diff --git a/mindinsight/mindconverter/mappings/f_mappings.json b/mindinsight/mindconverter/mappings/f_mappings.json index 94302177f5c53594be86b6aab1bca21b6bcbd266..86ed58d3553f4da581356fcf99134bc2ef1e97f1 100644 --- a/mindinsight/mindconverter/mappings/f_mappings.json +++ b/mindinsight/mindconverter/mappings/f_mappings.json @@ -21,14 +21,13 @@ "kernel_size": "REQUIRED", "stride": null, "padding": 0, - "dilation": 1, "ceil_mode": false, - "return_indices": false + "count_include_pad": true, + "divisor_override": null } ], "ms2pt_mapping": { "ksize": "kernel_size", - "strides": "stride", "input": "input" }, "gen_explicit_map": "gen_explicit_map_f_max_pool2d" @@ -62,7 +61,6 @@ ], "ms2pt_mapping": { "ksize": "kernel_size", - "strides": "stride", "input": "input" }, "gen_explicit_map": "gen_explicit_map_f_max_pool2d" diff --git a/mindinsight/mindconverter/mappings/nn_mappings.json b/mindinsight/mindconverter/mappings/nn_mappings.json index b3d67149dd511411c0aa3d284afa589c139def30..6c892612ef118c3c16389e64f71e289d7f5721ce 100644 --- a/mindinsight/mindconverter/mappings/nn_mappings.json +++ b/mindinsight/mindconverter/mappings/nn_mappings.json @@ -16,9 +16,7 @@ "inplace": false } ], - "ms2pt_mapping": { - "keep_prob": "p" - }, + "ms2pt_mapping": {}, "gen_explicit_map": "nn_dropout_gen_explicit_map" }, "nn.AvgPool2d": { @@ -36,14 +34,13 @@ "kernel_size": "REQUIRED", "stride": null, "padding": 0, - "dilation": 1, - "return_indices": false, - "ceil_mode": "False" + "ceil_mode": false, + "count_include_pad": true, + "divisor_override": null } ], "ms2pt_mapping": { - "kernel_size": "kernel_size", - "stride": "stride" + "kernel_size": "kernel_size" }, "gen_explicit_map": "gen_explicit_map_nn_maxpool2d" }, @@ -68,8 +65,7 @@ } ], "ms2pt_mapping": { - "kernel_size": "kernel_size", - "stride": "stride" + "kernel_size": "kernel_size" }, "gen_explicit_map": "gen_explicit_map_nn_maxpool2d" }, diff --git a/tests/ut/mindconverter/test_converter.py b/tests/ut/mindconverter/test_converter.py index c55ba1f2e817044aaed7385ae0cdb66b4db344c8..eda165323eb8e72f06bf8da651040ec9037d2e58 100644 --- a/tests/ut/mindconverter/test_converter.py +++ b/tests/ut/mindconverter/test_converter.py @@ -64,6 +64,15 @@ class TestConverter: assert replaced_code == code.replace('nn.Softmax(dim=1)', '{}(axis=1)'.format(expected_ms_api_name)) + def test_convert_api_nn_dropout(self): + """Test convert_api function work ok when convert api nn.Dropout""" + code = """nn.Dropout(0.3)""" + expected_ms_api_name = 'nn.Dropout' + + replaced_code = self.converter_ins.convert_api(code) + assert replaced_code == code.replace('nn.Dropout(0.3)', + "{}(keep_prob=0.7)".format(expected_ms_api_name)) + # test convert_api with torch dot ops def test_convert_api_torch_dot_abs(self): """Test convert_api function work ok when convert api torch.abs""" @@ -202,6 +211,33 @@ class TestConverter: assert replaced_code == code.replace('F.sigmoid(input)', '{}()(input)'.format(expected_ms_api_name)) + def test_convert_api_f_max_pool2d(self): + """Test convert_api function work ok when convert api F.max_pool2d""" + code = """F.max_pool2d(out, 2)""" + expected_ms_api_name = 'P.MaxPool' + + replaced_code = self.converter_ins.convert_api(code) + assert replaced_code == code.replace('F.max_pool2d(out, 2)', + "{}(2, 2, 'valid')(out)".format(expected_ms_api_name)) + + def test_convert_api_f_avg_pool2d_without_strides(self): + """Test convert_api function work ok when convert api F.avg_pool2d""" + code = """F.avg_pool2d(out, 2)""" + expected_ms_api_name = 'P.AvgPool' + + replaced_code = self.converter_ins.convert_api(code) + assert replaced_code == code.replace('F.avg_pool2d(out, 2)', + "{}(2, 2, 'valid')(out)".format(expected_ms_api_name)) + + def test_convert_api_f_avg_pool2d_with_strides(self): + """Test convert_api function work ok when convert api F.avg_pool2d""" + code = """F.avg_pool2d(out, 2, 3)""" + expected_ms_api_name = 'P.AvgPool' + + replaced_code = self.converter_ins.convert_api(code) + assert replaced_code == code.replace('F.avg_pool2d(out, 2, 3)', + "{}(2, 3, 'valid')(out)".format(expected_ms_api_name)) + # test convert_api with tensor dot ops def test_convert_api_tensor_dot_repeat(self): """Test convert_api function work ok when convert api .repeat""" @@ -216,7 +252,6 @@ class TestConverter: """Test convert_api function work ok when convert api .permute""" code = "x.permute(2, 0, 1)" expected_ms_api_name = 'P.Transpose' - replaced_code = self.converter_ins.convert_api(code) assert replaced_code == code.replace('x.permute(2, 0, 1)', '{}()(x, (2, 0, 1,))'.format(expected_ms_api_name))