提交 f4ca687d 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!277 correct mapping relationship

Merge pull request !277 from quyongxiu1/br_fix_mapping
...@@ -71,9 +71,9 @@ class APIPt: ...@@ -71,9 +71,9 @@ class APIPt:
or the given args_str not valid. or the given args_str not valid.
""" """
# expr is REQUIRED to meet (**) format # expr is REQUIRED to meet (**) format
if not (len(args_str) >= 2 and args_str[0] == "(" and args_str[-1] == ")"): if not (len(args_str) >= 2 and args_str[0] == "(" and args_str.strip()[-1] == ")"):
raise ValueError('[{}] is think as args str, it should start with "(" and end with ")"'.format(args_str)) raise ValueError('"{}" is think as args string, it should start with "(" and end with ")" without '
'considering spaces'.format(args_str))
try: try:
ast_node = ast.parse("whatever_call_name" + args_str) ast_node = ast.parse("whatever_call_name" + args_str)
call_node = ast_node.body[0].value call_node = ast_node.body[0].value
......
...@@ -35,7 +35,14 @@ def gen_explicit_map_f_max_pool2d(params_pt, args_pt): ...@@ -35,7 +35,14 @@ def gen_explicit_map_f_max_pool2d(params_pt, args_pt):
padding = "'valid'" padding = "'valid'"
else: else:
padding = "'same'" padding = "'same'"
return {"padding": padding}
if 'stride' in args_pt:
strides = args_pt['stride']
else:
strides = args_pt['kernel_size']
return {"padding": padding,
"strides": strides}
def gen_explicit_map_nn_sequential(_, args_pt): def gen_explicit_map_nn_sequential(_, args_pt):
...@@ -97,7 +104,14 @@ def gen_explicit_map_nn_maxpool2d(params_pt, args_pt): ...@@ -97,7 +104,14 @@ def gen_explicit_map_nn_maxpool2d(params_pt, args_pt):
pad_mode = "'valid'" pad_mode = "'valid'"
else: else:
pad_mode = "'same'" pad_mode = "'same'"
return {"pad_mode": pad_mode}
if 'stride' in args_pt:
stride = args_pt['stride']
else:
stride = args_pt['kernel_size']
return {"pad_mode": pad_mode,
"stride": stride}
def torch_dot_eye_gen_explicit_map(_, args_pt): def torch_dot_eye_gen_explicit_map(_, args_pt):
......
...@@ -21,14 +21,13 @@ ...@@ -21,14 +21,13 @@
"kernel_size": "REQUIRED", "kernel_size": "REQUIRED",
"stride": null, "stride": null,
"padding": 0, "padding": 0,
"dilation": 1,
"ceil_mode": false, "ceil_mode": false,
"return_indices": false "count_include_pad": true,
"divisor_override": null
} }
], ],
"ms2pt_mapping": { "ms2pt_mapping": {
"ksize": "kernel_size", "ksize": "kernel_size",
"strides": "stride",
"input": "input" "input": "input"
}, },
"gen_explicit_map": "gen_explicit_map_f_max_pool2d" "gen_explicit_map": "gen_explicit_map_f_max_pool2d"
...@@ -62,7 +61,6 @@ ...@@ -62,7 +61,6 @@
], ],
"ms2pt_mapping": { "ms2pt_mapping": {
"ksize": "kernel_size", "ksize": "kernel_size",
"strides": "stride",
"input": "input" "input": "input"
}, },
"gen_explicit_map": "gen_explicit_map_f_max_pool2d" "gen_explicit_map": "gen_explicit_map_f_max_pool2d"
......
...@@ -16,9 +16,7 @@ ...@@ -16,9 +16,7 @@
"inplace": false "inplace": false
} }
], ],
"ms2pt_mapping": { "ms2pt_mapping": {},
"keep_prob": "p"
},
"gen_explicit_map": "nn_dropout_gen_explicit_map" "gen_explicit_map": "nn_dropout_gen_explicit_map"
}, },
"nn.AvgPool2d": { "nn.AvgPool2d": {
...@@ -36,14 +34,13 @@ ...@@ -36,14 +34,13 @@
"kernel_size": "REQUIRED", "kernel_size": "REQUIRED",
"stride": null, "stride": null,
"padding": 0, "padding": 0,
"dilation": 1, "ceil_mode": false,
"return_indices": false, "count_include_pad": true,
"ceil_mode": "False" "divisor_override": null
} }
], ],
"ms2pt_mapping": { "ms2pt_mapping": {
"kernel_size": "kernel_size", "kernel_size": "kernel_size"
"stride": "stride"
}, },
"gen_explicit_map": "gen_explicit_map_nn_maxpool2d" "gen_explicit_map": "gen_explicit_map_nn_maxpool2d"
}, },
...@@ -68,8 +65,7 @@ ...@@ -68,8 +65,7 @@
} }
], ],
"ms2pt_mapping": { "ms2pt_mapping": {
"kernel_size": "kernel_size", "kernel_size": "kernel_size"
"stride": "stride"
}, },
"gen_explicit_map": "gen_explicit_map_nn_maxpool2d" "gen_explicit_map": "gen_explicit_map_nn_maxpool2d"
}, },
......
...@@ -64,6 +64,15 @@ class TestConverter: ...@@ -64,6 +64,15 @@ class TestConverter:
assert replaced_code == code.replace('nn.Softmax(dim=1)', assert replaced_code == code.replace('nn.Softmax(dim=1)',
'{}(axis=1)'.format(expected_ms_api_name)) '{}(axis=1)'.format(expected_ms_api_name))
def test_convert_api_nn_dropout(self):
"""Test convert_api function work ok when convert api nn.Dropout"""
code = """nn.Dropout(0.3)"""
expected_ms_api_name = 'nn.Dropout'
replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('nn.Dropout(0.3)',
"{}(keep_prob=0.7)".format(expected_ms_api_name))
# test convert_api with torch dot ops # test convert_api with torch dot ops
def test_convert_api_torch_dot_abs(self): def test_convert_api_torch_dot_abs(self):
"""Test convert_api function work ok when convert api torch.abs""" """Test convert_api function work ok when convert api torch.abs"""
...@@ -202,6 +211,33 @@ class TestConverter: ...@@ -202,6 +211,33 @@ class TestConverter:
assert replaced_code == code.replace('F.sigmoid(input)', assert replaced_code == code.replace('F.sigmoid(input)',
'{}()(input)'.format(expected_ms_api_name)) '{}()(input)'.format(expected_ms_api_name))
def test_convert_api_f_max_pool2d(self):
"""Test convert_api function work ok when convert api F.max_pool2d"""
code = """F.max_pool2d(out, 2)"""
expected_ms_api_name = 'P.MaxPool'
replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('F.max_pool2d(out, 2)',
"{}(2, 2, 'valid')(out)".format(expected_ms_api_name))
def test_convert_api_f_avg_pool2d_without_strides(self):
"""Test convert_api function work ok when convert api F.avg_pool2d"""
code = """F.avg_pool2d(out, 2)"""
expected_ms_api_name = 'P.AvgPool'
replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('F.avg_pool2d(out, 2)',
"{}(2, 2, 'valid')(out)".format(expected_ms_api_name))
def test_convert_api_f_avg_pool2d_with_strides(self):
"""Test convert_api function work ok when convert api F.avg_pool2d"""
code = """F.avg_pool2d(out, 2, 3)"""
expected_ms_api_name = 'P.AvgPool'
replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('F.avg_pool2d(out, 2, 3)',
"{}(2, 3, 'valid')(out)".format(expected_ms_api_name))
# test convert_api with tensor dot ops # test convert_api with tensor dot ops
def test_convert_api_tensor_dot_repeat(self): def test_convert_api_tensor_dot_repeat(self):
"""Test convert_api function work ok when convert api .repeat""" """Test convert_api function work ok when convert api .repeat"""
...@@ -216,7 +252,6 @@ class TestConverter: ...@@ -216,7 +252,6 @@ class TestConverter:
"""Test convert_api function work ok when convert api .permute""" """Test convert_api function work ok when convert api .permute"""
code = "x.permute(2, 0, 1)" code = "x.permute(2, 0, 1)"
expected_ms_api_name = 'P.Transpose' expected_ms_api_name = 'P.Transpose'
replaced_code = self.converter_ins.convert_api(code) replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('x.permute(2, 0, 1)', assert replaced_code == code.replace('x.permute(2, 0, 1)',
'{}()(x, (2, 0, 1,))'.format(expected_ms_api_name)) '{}()(x, (2, 0, 1,))'.format(expected_ms_api_name))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册