Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
mindinsight
提交
f4ca687d
M
mindinsight
项目概览
MindSpore
/
mindinsight
通知
8
Star
4
Fork
2
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindinsight
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f4ca687d
编写于
6月 18, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 18, 2020
浏览文件
操作
浏览文件
下载
差异文件
!277 correct mapping relationship
Merge pull request !277 from quyongxiu1/br_fix_mapping
上级
a01709e7
9a93920f
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
63 addition
and
20 deletion
+63
-20
mindinsight/mindconverter/config.py
mindinsight/mindconverter/config.py
+3
-3
mindinsight/mindconverter/funcs.py
mindinsight/mindconverter/funcs.py
+16
-2
mindinsight/mindconverter/mappings/f_mappings.json
mindinsight/mindconverter/mappings/f_mappings.json
+2
-4
mindinsight/mindconverter/mappings/nn_mappings.json
mindinsight/mindconverter/mappings/nn_mappings.json
+6
-10
tests/ut/mindconverter/test_converter.py
tests/ut/mindconverter/test_converter.py
+36
-1
未找到文件。
mindinsight/mindconverter/config.py
浏览文件 @
f4ca687d
...
...
@@ -71,9 +71,9 @@ class APIPt:
or the given args_str not valid.
"""
# expr is REQUIRED to meet (**) format
if
not
(
len
(
args_str
)
>=
2
and
args_str
[
0
]
==
"("
and
args_str
[
-
1
]
==
")"
):
raise
ValueError
(
'
[{}] is think as args str, it should start with "(" and end with ")"'
.
format
(
args_str
))
if
not
(
len
(
args_str
)
>=
2
and
args_str
[
0
]
==
"("
and
args_str
.
strip
()
[
-
1
]
==
")"
):
raise
ValueError
(
'
"{}" is think as args string, it should start with "(" and end with ")" without '
'considering spaces'
.
format
(
args_str
))
try
:
ast_node
=
ast
.
parse
(
"whatever_call_name"
+
args_str
)
call_node
=
ast_node
.
body
[
0
].
value
...
...
mindinsight/mindconverter/funcs.py
浏览文件 @
f4ca687d
...
...
@@ -35,7 +35,14 @@ def gen_explicit_map_f_max_pool2d(params_pt, args_pt):
padding
=
"'valid'"
else
:
padding
=
"'same'"
return
{
"padding"
:
padding
}
if
'stride'
in
args_pt
:
strides
=
args_pt
[
'stride'
]
else
:
strides
=
args_pt
[
'kernel_size'
]
return
{
"padding"
:
padding
,
"strides"
:
strides
}
def
gen_explicit_map_nn_sequential
(
_
,
args_pt
):
...
...
@@ -97,7 +104,14 @@ def gen_explicit_map_nn_maxpool2d(params_pt, args_pt):
pad_mode
=
"'valid'"
else
:
pad_mode
=
"'same'"
return
{
"pad_mode"
:
pad_mode
}
if
'stride'
in
args_pt
:
stride
=
args_pt
[
'stride'
]
else
:
stride
=
args_pt
[
'kernel_size'
]
return
{
"pad_mode"
:
pad_mode
,
"stride"
:
stride
}
def
torch_dot_eye_gen_explicit_map
(
_
,
args_pt
):
...
...
mindinsight/mindconverter/mappings/f_mappings.json
浏览文件 @
f4ca687d
...
...
@@ -21,14 +21,13 @@
"kernel_size"
:
"REQUIRED"
,
"stride"
:
null
,
"padding"
:
0
,
"dilation"
:
1
,
"ceil_mode"
:
false
,
"return_indices"
:
false
"count_include_pad"
:
true
,
"divisor_override"
:
null
}
],
"ms2pt_mapping"
:
{
"ksize"
:
"kernel_size"
,
"strides"
:
"stride"
,
"input"
:
"input"
},
"gen_explicit_map"
:
"gen_explicit_map_f_max_pool2d"
...
...
@@ -62,7 +61,6 @@
],
"ms2pt_mapping"
:
{
"ksize"
:
"kernel_size"
,
"strides"
:
"stride"
,
"input"
:
"input"
},
"gen_explicit_map"
:
"gen_explicit_map_f_max_pool2d"
...
...
mindinsight/mindconverter/mappings/nn_mappings.json
浏览文件 @
f4ca687d
...
...
@@ -16,9 +16,7 @@
"inplace"
:
false
}
],
"ms2pt_mapping"
:
{
"keep_prob"
:
"p"
},
"ms2pt_mapping"
:
{},
"gen_explicit_map"
:
"nn_dropout_gen_explicit_map"
},
"nn.AvgPool2d"
:
{
...
...
@@ -36,14 +34,13 @@
"kernel_size"
:
"REQUIRED"
,
"stride"
:
null
,
"padding"
:
0
,
"
dilation"
:
1
,
"
return_indices"
:
fals
e
,
"
ceil_mode"
:
"False"
"
ceil_mode"
:
false
,
"
count_include_pad"
:
tru
e
,
"
divisor_override"
:
null
}
],
"ms2pt_mapping"
:
{
"kernel_size"
:
"kernel_size"
,
"stride"
:
"stride"
"kernel_size"
:
"kernel_size"
},
"gen_explicit_map"
:
"gen_explicit_map_nn_maxpool2d"
},
...
...
@@ -68,8 +65,7 @@
}
],
"ms2pt_mapping"
:
{
"kernel_size"
:
"kernel_size"
,
"stride"
:
"stride"
"kernel_size"
:
"kernel_size"
},
"gen_explicit_map"
:
"gen_explicit_map_nn_maxpool2d"
},
...
...
tests/ut/mindconverter/test_converter.py
浏览文件 @
f4ca687d
...
...
@@ -64,6 +64,15 @@ class TestConverter:
assert
replaced_code
==
code
.
replace
(
'nn.Softmax(dim=1)'
,
'{}(axis=1)'
.
format
(
expected_ms_api_name
))
def
test_convert_api_nn_dropout
(
self
):
"""Test convert_api function work ok when convert api nn.Dropout"""
code
=
"""nn.Dropout(0.3)"""
expected_ms_api_name
=
'nn.Dropout'
replaced_code
=
self
.
converter_ins
.
convert_api
(
code
)
assert
replaced_code
==
code
.
replace
(
'nn.Dropout(0.3)'
,
"{}(keep_prob=0.7)"
.
format
(
expected_ms_api_name
))
# test convert_api with torch dot ops
def
test_convert_api_torch_dot_abs
(
self
):
"""Test convert_api function work ok when convert api torch.abs"""
...
...
@@ -202,6 +211,33 @@ class TestConverter:
assert
replaced_code
==
code
.
replace
(
'F.sigmoid(input)'
,
'{}()(input)'
.
format
(
expected_ms_api_name
))
def
test_convert_api_f_max_pool2d
(
self
):
"""Test convert_api function work ok when convert api F.max_pool2d"""
code
=
"""F.max_pool2d(out, 2)"""
expected_ms_api_name
=
'P.MaxPool'
replaced_code
=
self
.
converter_ins
.
convert_api
(
code
)
assert
replaced_code
==
code
.
replace
(
'F.max_pool2d(out, 2)'
,
"{}(2, 2, 'valid')(out)"
.
format
(
expected_ms_api_name
))
def
test_convert_api_f_avg_pool2d_without_strides
(
self
):
"""Test convert_api function work ok when convert api F.avg_pool2d"""
code
=
"""F.avg_pool2d(out, 2)"""
expected_ms_api_name
=
'P.AvgPool'
replaced_code
=
self
.
converter_ins
.
convert_api
(
code
)
assert
replaced_code
==
code
.
replace
(
'F.avg_pool2d(out, 2)'
,
"{}(2, 2, 'valid')(out)"
.
format
(
expected_ms_api_name
))
def
test_convert_api_f_avg_pool2d_with_strides
(
self
):
"""Test convert_api function work ok when convert api F.avg_pool2d"""
code
=
"""F.avg_pool2d(out, 2, 3)"""
expected_ms_api_name
=
'P.AvgPool'
replaced_code
=
self
.
converter_ins
.
convert_api
(
code
)
assert
replaced_code
==
code
.
replace
(
'F.avg_pool2d(out, 2, 3)'
,
"{}(2, 3, 'valid')(out)"
.
format
(
expected_ms_api_name
))
# test convert_api with tensor dot ops
def
test_convert_api_tensor_dot_repeat
(
self
):
"""Test convert_api function work ok when convert api .repeat"""
...
...
@@ -216,7 +252,6 @@ class TestConverter:
"""Test convert_api function work ok when convert api .permute"""
code
=
"x.permute(2, 0, 1)"
expected_ms_api_name
=
'P.Transpose'
replaced_code
=
self
.
converter_ins
.
convert_api
(
code
)
assert
replaced_code
==
code
.
replace
(
'x.permute(2, 0, 1)'
,
'{}()(x, (2, 0, 1,))'
.
format
(
expected_ms_api_name
))
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录