提交 f4ca687d 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!277 correct mapping relationship

Merge pull request !277 from quyongxiu1/br_fix_mapping
......@@ -71,9 +71,9 @@ class APIPt:
or the given args_str not valid.
"""
# expr is REQUIRED to meet (**) format
if not (len(args_str) >= 2 and args_str[0] == "(" and args_str[-1] == ")"):
raise ValueError('[{}] is think as args str, it should start with "(" and end with ")"'.format(args_str))
if not (len(args_str) >= 2 and args_str[0] == "(" and args_str.strip()[-1] == ")"):
raise ValueError('"{}" is think as args string, it should start with "(" and end with ")" without '
'considering spaces'.format(args_str))
try:
ast_node = ast.parse("whatever_call_name" + args_str)
call_node = ast_node.body[0].value
......
......@@ -35,7 +35,14 @@ def gen_explicit_map_f_max_pool2d(params_pt, args_pt):
padding = "'valid'"
else:
padding = "'same'"
return {"padding": padding}
if 'stride' in args_pt:
strides = args_pt['stride']
else:
strides = args_pt['kernel_size']
return {"padding": padding,
"strides": strides}
def gen_explicit_map_nn_sequential(_, args_pt):
......@@ -97,7 +104,14 @@ def gen_explicit_map_nn_maxpool2d(params_pt, args_pt):
pad_mode = "'valid'"
else:
pad_mode = "'same'"
return {"pad_mode": pad_mode}
if 'stride' in args_pt:
stride = args_pt['stride']
else:
stride = args_pt['kernel_size']
return {"pad_mode": pad_mode,
"stride": stride}
def torch_dot_eye_gen_explicit_map(_, args_pt):
......
......@@ -21,14 +21,13 @@
"kernel_size": "REQUIRED",
"stride": null,
"padding": 0,
"dilation": 1,
"ceil_mode": false,
"return_indices": false
"count_include_pad": true,
"divisor_override": null
}
],
"ms2pt_mapping": {
"ksize": "kernel_size",
"strides": "stride",
"input": "input"
},
"gen_explicit_map": "gen_explicit_map_f_max_pool2d"
......@@ -62,7 +61,6 @@
],
"ms2pt_mapping": {
"ksize": "kernel_size",
"strides": "stride",
"input": "input"
},
"gen_explicit_map": "gen_explicit_map_f_max_pool2d"
......
......@@ -16,9 +16,7 @@
"inplace": false
}
],
"ms2pt_mapping": {
"keep_prob": "p"
},
"ms2pt_mapping": {},
"gen_explicit_map": "nn_dropout_gen_explicit_map"
},
"nn.AvgPool2d": {
......@@ -36,14 +34,13 @@
"kernel_size": "REQUIRED",
"stride": null,
"padding": 0,
"dilation": 1,
"return_indices": false,
"ceil_mode": "False"
"ceil_mode": false,
"count_include_pad": true,
"divisor_override": null
}
],
"ms2pt_mapping": {
"kernel_size": "kernel_size",
"stride": "stride"
"kernel_size": "kernel_size"
},
"gen_explicit_map": "gen_explicit_map_nn_maxpool2d"
},
......@@ -68,8 +65,7 @@
}
],
"ms2pt_mapping": {
"kernel_size": "kernel_size",
"stride": "stride"
"kernel_size": "kernel_size"
},
"gen_explicit_map": "gen_explicit_map_nn_maxpool2d"
},
......
......@@ -64,6 +64,15 @@ class TestConverter:
assert replaced_code == code.replace('nn.Softmax(dim=1)',
'{}(axis=1)'.format(expected_ms_api_name))
def test_convert_api_nn_dropout(self):
"""Test convert_api function work ok when convert api nn.Dropout"""
code = """nn.Dropout(0.3)"""
expected_ms_api_name = 'nn.Dropout'
replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('nn.Dropout(0.3)',
"{}(keep_prob=0.7)".format(expected_ms_api_name))
# test convert_api with torch dot ops
def test_convert_api_torch_dot_abs(self):
"""Test convert_api function work ok when convert api torch.abs"""
......@@ -202,6 +211,33 @@ class TestConverter:
assert replaced_code == code.replace('F.sigmoid(input)',
'{}()(input)'.format(expected_ms_api_name))
def test_convert_api_f_max_pool2d(self):
"""Test convert_api function work ok when convert api F.max_pool2d"""
code = """F.max_pool2d(out, 2)"""
expected_ms_api_name = 'P.MaxPool'
replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('F.max_pool2d(out, 2)',
"{}(2, 2, 'valid')(out)".format(expected_ms_api_name))
def test_convert_api_f_avg_pool2d_without_strides(self):
"""Test convert_api function work ok when convert api F.avg_pool2d"""
code = """F.avg_pool2d(out, 2)"""
expected_ms_api_name = 'P.AvgPool'
replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('F.avg_pool2d(out, 2)',
"{}(2, 2, 'valid')(out)".format(expected_ms_api_name))
def test_convert_api_f_avg_pool2d_with_strides(self):
"""Test convert_api function work ok when convert api F.avg_pool2d"""
code = """F.avg_pool2d(out, 2, 3)"""
expected_ms_api_name = 'P.AvgPool'
replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('F.avg_pool2d(out, 2, 3)',
"{}(2, 3, 'valid')(out)".format(expected_ms_api_name))
# test convert_api with tensor dot ops
def test_convert_api_tensor_dot_repeat(self):
"""Test convert_api function work ok when convert api .repeat"""
......@@ -216,7 +252,6 @@ class TestConverter:
"""Test convert_api function work ok when convert api .permute"""
code = "x.permute(2, 0, 1)"
expected_ms_api_name = 'P.Transpose'
replaced_code = self.converter_ins.convert_api(code)
assert replaced_code == code.replace('x.permute(2, 0, 1)',
'{}()(x, (2, 0, 1,))'.format(expected_ms_api_name))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册