Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
mindinsight
提交
16fcb034
M
mindinsight
项目概览
MindSpore
/
mindinsight
通知
7
Star
3
Fork
2
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindinsight
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
16fcb034
编写于
5月 28, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
5月 28, 2020
浏览文件
操作
浏览文件
下载
差异文件
!187 more clear report on mindconverter
Merge pull request !187 from quyongxiu1/mindconvert_report_fix_0.3
上级
34a9026e
5e49d578
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
103 addition
and
37 deletion
+103
-37
mindinsight/mindconverter/converter.py
mindinsight/mindconverter/converter.py
+103
-37
未找到文件。
mindinsight/mindconverter/converter.py
浏览文件 @
16fcb034
...
...
@@ -28,6 +28,8 @@ from mindinsight.mindconverter.config import ALL_UNSUPPORTED
from
mindinsight.mindconverter.common.log
import
logger
from
mindinsight.mindconverter.forward_call
import
ForwardCall
LINE_NO_INDEX_DIFF
=
1
class
Converter
:
"""Convert class"""
...
...
@@ -197,6 +199,7 @@ class Converter:
raise
ValueError
(
'"(" not found, {} should work with "("'
.
format
(
call_name
))
right
=
self
.
find_right_parentheses
(
code
,
left
)
end
=
right
expr
=
code
[
start
:
end
+
1
]
args_str
=
code
[
left
:
right
+
1
]
...
...
@@ -336,6 +339,96 @@ class Converter:
mapping
.
update
(
convert_fun
(
*
args
))
return
mapping
@
staticmethod
def
get_code_start_line_num
(
source_lines
):
"""
Get the start code line number exclude comments.
Args:
source_lines (list[str]): Split results of original code.
Returns:
int, the start line number.
"""
stack
=
[]
index
=
0
for
i
,
line
in
enumerate
(
source_lines
):
if
line
.
strip
().
startswith
(
'#'
):
continue
if
line
.
strip
().
startswith
(
'"""'
):
if
not
line
.
endswith
(
'"""
\n
'
):
stack
.
append
(
'"""'
)
continue
if
line
.
strip
().
startswith
(
"'''"
):
if
not
line
.
endswith
(
"'''
\n
"
):
stack
.
append
(
"'''"
)
continue
if
line
.
endswith
(
'"""
\n
'
)
or
line
.
endswith
(
"'''
\n
"
):
stack
.
pop
()
continue
if
line
.
strip
()
!=
''
and
not
stack
:
index
=
i
break
return
index
def
update_code_and_convert_info
(
self
,
code
,
mapping
):
"""
Replace code according to mapping, and update convert info.
Args:
code (str): The code to replace.
mapping (dict): Mapping for original code and the replaced code.
Returns:
str, the replaced code.
"""
for
key
,
value
in
mapping
.
items
():
code
=
code
.
replace
(
key
,
value
)
source_lines
=
code
.
splitlines
(
keepends
=
True
)
start_line_number
=
self
.
get_code_start_line_num
(
source_lines
)
add_import_infos
=
[
'import mindspore
\n
'
,
'import mindspore.nn as nn
\n
'
,
'import mindspore.ops.operations as P
\n
'
]
for
i
,
add_import_info
in
enumerate
(
add_import_infos
):
source_lines
.
insert
(
start_line_number
+
i
,
add_import_info
)
self
.
convert_info
+=
'[Add Import] {}.
\n
'
.
format
(
add_import_info
.
strip
())
insert_count
=
len
(
add_import_infos
)
line_diff
=
insert_count
-
LINE_NO_INDEX_DIFF
for
i
in
range
(
start_line_number
+
insert_count
,
len
(
source_lines
)):
line
=
source_lines
[
i
]
if
(
line
.
startswith
(
'from torch'
)
and
'import'
in
line
)
or
line
.
startswith
(
'import torch'
):
new_line
=
'# '
+
line
source_lines
[
i
]
=
new_line
self
.
convert_info
+=
'[Annotate][Line{:3d}] {} is annotated.
\n
'
.
format
(
i
-
line_diff
,
line
.
strip
())
if
line
.
strip
().
startswith
(
'class'
)
and
'(nn.Module)'
in
line
:
new_line
=
line
.
replace
(
'nn.Module'
,
'nn.Cell'
)
source_lines
[
i
]
=
new_line
self
.
convert_info
+=
'[Convert][Line{:3d}] nn.Module is converted.
\n
'
.
format
(
i
-
line_diff
)
if
line
.
strip
().
startswith
(
'def forward('
):
new_line
=
line
.
replace
(
'forward'
,
'construct'
)
source_lines
[
i
]
=
new_line
self
.
convert_info
+=
'[Convert][Line{:3d}] forward is converted.
\n
'
.
format
(
i
-
line_diff
)
if
'nn.Linear'
in
line
:
new_line
=
line
.
replace
(
'nn.Linear'
,
'nn.Dense'
)
source_lines
[
i
]
=
new_line
self
.
convert_info
+=
'[Convert][Line{:3d}] nn.Linear is converted.
\n
'
.
format
(
i
-
line_diff
)
if
'(nn.Sequential)'
in
line
:
new_line
=
line
.
replace
(
'nn.Sequential'
,
'nn.SequentialCell'
)
source_lines
[
i
]
=
new_line
self
.
convert_info
+=
'[Convert][Line{:3d}] nn.Sequential is converted.
\n
'
.
format
(
i
-
line_diff
)
if
'nn.init.'
in
line
:
new_line
=
line
.
replace
(
'nn.init'
,
'pass # nn.init'
)
source_lines
[
i
]
=
new_line
self
.
convert_info
+=
'[Annotate][Line{:3d}] {} is annotated.
\n
'
.
format
(
i
-
line_diff
,
'nn.init'
)
code
=
''
.
join
(
source_lines
)
return
code
def
convert
(
self
,
import_name
,
output_dir
,
report_dir
):
"""
Convert a module's code, code converted will be save in output_dir, and a report will be save in report_dir.
...
...
@@ -346,10 +439,10 @@ class Converter:
report_dir (str): The path to save report file.
"""
logger
.
info
(
"Start converting %s"
,
import_name
)
self
.
convert_info
+=
'[Start Convert]
\n
The module is {}
\n
'
.
format
(
import_name
)
start_info
=
'[Start Convert]
\n
'
module_info
=
'The module is {}.
\n
'
.
format
(
import_name
)
import_mod
=
importlib
.
import_module
(
import_name
)
srcfile
=
inspect
.
getsourcefile
(
import_mod
)
logger
.
info
(
"Script file is %s"
,
srcfile
)
...
...
@@ -358,40 +451,14 @@ class Converter:
# replace python function under nn.Module
mapping
=
self
.
get_mapping
(
import_mod
,
forward_list
)
code
=
inspect
.
getsource
(
import_mod
)
for
key
,
value
in
mapping
.
items
():
code
=
code
.
replace
(
key
,
value
)
code
=
'import mindspore.ops.operations as P
\n
'
+
code
code
=
'import mindspore.nn as nn
\n
'
+
code
code
=
'import mindspore
\n
'
+
code
self
.
convert_info
+=
'||[Import Add] Add follow import sentences:
\n
'
self
.
convert_info
+=
'import mindspore.ops.operations as P
\n
'
self
.
convert_info
+=
'import mindspore.nn as nn
\n
'
self
.
convert_info
+=
'import mindspore
\n\n
'
code
=
code
.
replace
(
'import torch'
,
'# import torch'
)
code
=
code
.
replace
(
'from torch'
,
'# from torch'
)
code
=
code
.
replace
(
'(nn.Module):'
,
'(nn.Cell):'
)
code
=
code
.
replace
(
'forward('
,
'construct('
)
code
=
code
.
replace
(
'nn.Linear'
,
'nn.Dense'
)
code
=
code
.
replace
(
'(nn.Sequential)'
,
'(nn.SequentialCell)'
)
code
=
code
.
replace
(
'nn.init.'
,
'pass # nn.init.'
)
self
.
convert_info
+=
'||[Import Annotated] Annotated follow import sentences:
\n
'
self
.
convert_info
+=
'import sentence on torch as follows are annotated:
\n
'
self
.
convert_info
+=
'import torch
\n
'
self
.
convert_info
+=
'from torch ...
\n
'
self
.
convert_info
+=
'||[Explicit Convert] Module or function are explicitly converted as follows:
\n
'
self
.
convert_info
+=
'[nn.Module] is converted to [nn.Cell]
\n
'
self
.
convert_info
+=
'[forward] is converted to [construct]
\n
'
self
.
convert_info
+=
'[nn.Linear] is converted to [nn.Dense]
\n
'
self
.
convert_info
+=
'[nn.Sequential] is converted to [nn.SequentialCell]
\n
'
self
.
convert_info
+=
'[nn.init] is not converted and annotated
\n
'
self
.
convert_info
+=
'[Convert over]'
code
=
self
.
update_code_and_convert_info
(
code
,
mapping
)
convert_info_split
=
self
.
convert_info
.
splitlines
(
keepends
=
True
)
convert_info_split
=
sorted
(
convert_info_split
)
convert_info_split
.
insert
(
0
,
start_info
)
convert_info_split
.
insert
(
1
,
module_info
)
convert_info_split
.
append
(
'[Convert Over]'
)
self
.
convert_info
=
''
.
join
(
convert_info_split
)
dest_file
=
os
.
path
.
join
(
output_dir
,
os
.
path
.
basename
(
srcfile
))
with
os
.
fdopen
(
os
.
open
(
dest_file
,
self
.
flags
,
self
.
modes
),
'w'
)
as
file
:
...
...
@@ -428,7 +495,6 @@ def _path_split(file):
Returns:
list[str], list of file tail
"""
file_dir
,
name
=
os
.
path
.
split
(
file
)
if
file_dir
:
...
...
@@ -456,6 +522,6 @@ def main(files_config):
module_name
=
'.'
.
join
(
in_file_split
)
convert_ins
.
convert
(
module_name
,
files_config
[
'outfile_dir'
],
files_config
[
'report_dir'
])
in_module
=
files_config
[
'in_module'
]
in_module
=
files_config
.
get
(
'in_module'
)
if
in_module
:
convert_ins
.
convert
(
in_module
,
files_config
[
'outfile_dir'
],
files_config
[
'report_dir'
])
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录