From 9a93920fd08d2a060e2b56a439ffbfcd9d67e5f6 Mon Sep 17 00:00:00 2001 From: quyongxiu1 Date: Wed, 17 Jun 2020 11:20:46 +0800 Subject: [PATCH] fix some mapping relationships and add tests to verify the relationship --- mindinsight/mindconverter/config.py | 6 +-- mindinsight/mindconverter/funcs.py | 18 ++++++++- .../mindconverter/mappings/f_mappings.json | 6 +-- .../mindconverter/mappings/nn_mappings.json | 16 +++----- tests/ut/mindconverter/test_converter.py | 37 ++++++++++++++++++- 5 files changed, 63 insertions(+), 20 deletions(-) diff --git a/mindinsight/mindconverter/config.py b/mindinsight/mindconverter/config.py index b6843ef..840e4cf 100644 --- a/mindinsight/mindconverter/config.py +++ b/mindinsight/mindconverter/config.py @@ -71,9 +71,9 @@ class APIPt: or the given args_str not valid. """ # expr is REQUIRED to meet (**) format - if not (len(args_str) >= 2 and args_str[0] == "(" and args_str[-1] == ")"): - raise ValueError('[{}] is think as args str, it should start with "(" and end with ")"'.format(args_str)) - + if not (len(args_str) >= 2 and args_str[0] == "(" and args_str.strip()[-1] == ")"): + raise ValueError('"{}" is think as args string, it should start with "(" and end with ")" without ' + 'considering spaces'.format(args_str)) try: ast_node = ast.parse("whatever_call_name" + args_str) call_node = ast_node.body[0].value diff --git a/mindinsight/mindconverter/funcs.py b/mindinsight/mindconverter/funcs.py index e5e5081..ffc8cd7 100644 --- a/mindinsight/mindconverter/funcs.py +++ b/mindinsight/mindconverter/funcs.py @@ -35,7 +35,14 @@ def gen_explicit_map_f_max_pool2d(params_pt, args_pt): padding = "'valid'" else: padding = "'same'" - return {"padding": padding} + + if 'stride' in args_pt: + strides = args_pt['stride'] + else: + strides = args_pt['kernel_size'] + + return {"padding": padding, + "strides": strides} def gen_explicit_map_nn_sequential(_, args_pt): @@ -97,7 +104,14 @@ def gen_explicit_map_nn_maxpool2d(params_pt, args_pt): pad_mode = "'valid'" else: pad_mode = "'same'" - return {"pad_mode": pad_mode} + + if 'stride' in args_pt: + stride = args_pt['stride'] + else: + stride = args_pt['kernel_size'] + + return {"pad_mode": pad_mode, + "stride": stride} def torch_dot_eye_gen_explicit_map(_, args_pt): diff --git a/mindinsight/mindconverter/mappings/f_mappings.json b/mindinsight/mindconverter/mappings/f_mappings.json index 9430217..86ed58d 100644 --- a/mindinsight/mindconverter/mappings/f_mappings.json +++ b/mindinsight/mindconverter/mappings/f_mappings.json @@ -21,14 +21,13 @@ "kernel_size": "REQUIRED", "stride": null, "padding": 0, - "dilation": 1, "ceil_mode": false, - "return_indices": false + "count_include_pad": true, + "divisor_override": null } ], "ms2pt_mapping": { "ksize": "kernel_size", - "strides": "stride", "input": "input" }, "gen_explicit_map": "gen_explicit_map_f_max_pool2d" @@ -62,7 +61,6 @@ ], "ms2pt_mapping": { "ksize": "kernel_size", - "strides": "stride", "input": "input" }, "gen_explicit_map": "gen_explicit_map_f_max_pool2d" diff --git a/mindinsight/mindconverter/mappings/nn_mappings.json b/mindinsight/mindconverter/mappings/nn_mappings.json index b3d6714..6c89261 100644 --- a/mindinsight/mindconverter/mappings/nn_mappings.json +++ b/mindinsight/mindconverter/mappings/nn_mappings.json @@ -16,9 +16,7 @@ "inplace": false } ], - "ms2pt_mapping": { - "keep_prob": "p" - }, + "ms2pt_mapping": {}, "gen_explicit_map": "nn_dropout_gen_explicit_map" }, "nn.AvgPool2d": { @@ -36,14 +34,13 @@ "kernel_size": "REQUIRED", "stride": null, "padding": 0, - "dilation": 1, - "return_indices": false, - "ceil_mode": "False" + "ceil_mode": false, + "count_include_pad": true, + "divisor_override": null } ], "ms2pt_mapping": { - "kernel_size": "kernel_size", - "stride": "stride" + "kernel_size": "kernel_size" }, "gen_explicit_map": "gen_explicit_map_nn_maxpool2d" }, @@ -68,8 +65,7 @@ } ], "ms2pt_mapping": { - "kernel_size": "kernel_size", - "stride": "stride" + "kernel_size": "kernel_size" }, "gen_explicit_map": "gen_explicit_map_nn_maxpool2d" }, diff --git a/tests/ut/mindconverter/test_converter.py b/tests/ut/mindconverter/test_converter.py index c55ba1f..eda1653 100644 --- a/tests/ut/mindconverter/test_converter.py +++ b/tests/ut/mindconverter/test_converter.py @@ -64,6 +64,15 @@ class TestConverter: assert replaced_code == code.replace('nn.Softmax(dim=1)', '{}(axis=1)'.format(expected_ms_api_name)) + def test_convert_api_nn_dropout(self): + """Test convert_api function work ok when convert api nn.Dropout""" + code = """nn.Dropout(0.3)""" + expected_ms_api_name = 'nn.Dropout' + + replaced_code = self.converter_ins.convert_api(code) + assert replaced_code == code.replace('nn.Dropout(0.3)', + "{}(keep_prob=0.7)".format(expected_ms_api_name)) + # test convert_api with torch dot ops def test_convert_api_torch_dot_abs(self): """Test convert_api function work ok when convert api torch.abs""" @@ -202,6 +211,33 @@ class TestConverter: assert replaced_code == code.replace('F.sigmoid(input)', '{}()(input)'.format(expected_ms_api_name)) + def test_convert_api_f_max_pool2d(self): + """Test convert_api function work ok when convert api F.max_pool2d""" + code = """F.max_pool2d(out, 2)""" + expected_ms_api_name = 'P.MaxPool' + + replaced_code = self.converter_ins.convert_api(code) + assert replaced_code == code.replace('F.max_pool2d(out, 2)', + "{}(2, 2, 'valid')(out)".format(expected_ms_api_name)) + + def test_convert_api_f_avg_pool2d_without_strides(self): + """Test convert_api function work ok when convert api F.avg_pool2d""" + code = """F.avg_pool2d(out, 2)""" + expected_ms_api_name = 'P.AvgPool' + + replaced_code = self.converter_ins.convert_api(code) + assert replaced_code == code.replace('F.avg_pool2d(out, 2)', + "{}(2, 2, 'valid')(out)".format(expected_ms_api_name)) + + def test_convert_api_f_avg_pool2d_with_strides(self): + """Test convert_api function work ok when convert api F.avg_pool2d""" + code = """F.avg_pool2d(out, 2, 3)""" + expected_ms_api_name = 'P.AvgPool' + + replaced_code = self.converter_ins.convert_api(code) + assert replaced_code == code.replace('F.avg_pool2d(out, 2, 3)', + "{}(2, 3, 'valid')(out)".format(expected_ms_api_name)) + # test convert_api with tensor dot ops def test_convert_api_tensor_dot_repeat(self): """Test convert_api function work ok when convert api .repeat""" @@ -216,7 +252,6 @@ class TestConverter: """Test convert_api function work ok when convert api .permute""" code = "x.permute(2, 0, 1)" expected_ms_api_name = 'P.Transpose' - replaced_code = self.converter_ins.convert_api(code) assert replaced_code == code.replace('x.permute(2, 0, 1)', '{}()(x, (2, 0, 1,))'.format(expected_ms_api_name)) -- GitLab