Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
mindinsight
提交
2f2f6a3f
M
mindinsight
项目概览
MindSpore
/
mindinsight
通知
8
Star
4
Fork
2
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindinsight
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2f2f6a3f
编写于
6月 06, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 06, 2020
浏览文件
操作
浏览文件
下载
差异文件
!245 use json to save op mapping in mindconverter
Merge pull request !245 from quyongxiu1/br_0601_config_optim
上级
3abcb4a2
6adc58f0
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
660 addition
and
312 deletion
+660
-312
mindinsight/mindconverter/config.py
mindinsight/mindconverter/config.py
+62
-290
mindinsight/mindconverter/enums.py
mindinsight/mindconverter/enums.py
+0
-22
mindinsight/mindconverter/funcs.py
mindinsight/mindconverter/funcs.py
+106
-0
mindinsight/mindconverter/mappings/f_mappings.json
mindinsight/mindconverter/mappings/f_mappings.json
+108
-0
mindinsight/mindconverter/mappings/nn_mappings.json
mindinsight/mindconverter/mappings/nn_mappings.json
+220
-0
mindinsight/mindconverter/mappings/tensor_dot_mappings.json
mindinsight/mindconverter/mappings/tensor_dot_mappings.json
+119
-0
mindinsight/mindconverter/mappings/torch_dot_mappings.json
mindinsight/mindconverter/mappings/torch_dot_mappings.json
+45
-0
mindinsight/mindconverter/ops/f_list.json
mindinsight/mindconverter/ops/f_list.json
+0
-0
mindinsight/mindconverter/ops/nn_list.json
mindinsight/mindconverter/ops/nn_list.json
+0
-0
mindinsight/mindconverter/ops/tensor_dot_list.json
mindinsight/mindconverter/ops/tensor_dot_list.json
+0
-0
mindinsight/mindconverter/ops/torch_dot_list.json
mindinsight/mindconverter/ops/torch_dot_list.json
+0
-0
未找到文件。
mindinsight/mindconverter/config.py
浏览文件 @
2f2f6a3f
...
@@ -15,17 +15,18 @@
...
@@ -15,17 +15,18 @@
"""API config"""
"""API config"""
import
ast
import
ast
from
collections
import
OrderedDict
from
collections
import
OrderedDict
from
functools
import
partial
from
importlib
import
import_module
import
json
import
json
import
os
import
os
import
pasta
import
pasta
from
mindinsight.mindconverter.enums
import
RequriedType
from
mindinsight.mindconverter.common.log
import
logger
from
mindinsight.mindconverter.common.log
import
logger
REQUIRED
=
RequriedType
.
REQUIRED
.
name
UNREQUIRED
=
RequriedType
.
UNREQUIRED
.
name
REQUIRED
=
'REQUIRED'
UNREQUIRED
=
'UNREQUIRED'
FUNC_MODULE
=
'mindinsight.mindconverter.funcs'
class
APIPt
:
class
APIPt
:
...
@@ -250,88 +251,65 @@ class MappingHelper:
...
@@ -250,88 +251,65 @@ class MappingHelper:
return
expr_ms
return
expr_ms
def
ge
n_explicit_map_nn_sequential
(
_
,
args_pt
):
def
ge
t_ms_api
(
ms_api_info
):
"""
"""
Ge
nerate explicit_map for nn.Sequential
.
Ge
t APIMs instance from ms_api_info
.
Args:
Args:
args_pt (dict): Args for APIPt.
ms_api_info (list): info for create an APIMs instance, the first value in list is name for APIMs, the second(if
provided) is params for APIMs, the third(if provided) is p_attrs for APIMs.
Returns:
Returns:
dict, map between frames
.
APIMs, instance of APIMs parsed from given info
.
"""
"""
args
=
args_pt
[
'*args'
]
ms_name
=
ms_api_info
[
0
]
return
{
"*args"
:
"[{}]"
.
format
(
args
)}
ms_params
=
ms_api_info
[
1
]
if
len
(
ms_api_info
)
>=
2
else
None
ms_p_attrs
=
set
(
ms_api_info
[
2
])
if
len
(
ms_api_info
)
>=
3
else
None
ms_api
=
APIMs
(
name
=
ms_name
,
params
=
ms_params
,
p_attrs
=
ms_p_attrs
)
return
ms_api
def
ge
n_explicit_map_nn_maxpool2d
(
params_pt
,
args_pt
):
def
ge
t_pt_api
(
pt_api_info
):
"""
"""
Ge
nerate explicit_map for nn.MaxPool2d
.
Ge
t APIPt instance from pt_api_info
.
Args:
Args:
p
arams_pt (dict): Params for APIPt.
p
t_api_info (list): info for create an APIMs instance, the first value in list is name for APIPt, the second(if
args_pt (dict): Arg
s for APIPt.
provided) is param
s for APIPt.
Returns:
Returns:
dict, map between frames.
APIMs, instance of APIMs parsed from given info.
"""
if
'padding'
in
args_pt
:
padding
=
args_pt
[
'padding'
]
else
:
padding
=
params_pt
[
'padding'
]
if
padding
.
strip
()
in
(
"0"
,
"(0,0)"
,
"(0, 0)"
):
pad_mode
=
"'valid'"
else
:
pad_mode
=
"'same'"
return
{
"pad_mode"
:
pad_mode
}
def
gen_explicit_map_f_max_pool2d
(
params_pt
,
args_pt
):
"""
"""
Generate explicit_map for F.MaxPool2d.
pt_name
=
pt_api_info
[
0
]
pt_params
=
pt_api_info
[
1
]
if
len
(
pt_api_info
)
>=
2
else
None
pt_api
=
APIPt
(
name
=
pt_name
,
params
=
pt_params
)
return
pt_api
Args:
params_pt (dict): Params for APIPt.
args_pt (dict): Args for APIPt.
Returns:
def
get_mapping_from_file
(
path
):
dict, map between frames.
"""
"""
if
'padding'
in
args_pt
:
Parse mapping info from given file.
padding
=
args_pt
[
'padding'
]
else
:
padding
=
params_pt
[
'padding'
]
if
padding
.
strip
()
in
(
"0"
,
"(0,0)"
,
"(0, 0)"
):
padding
=
"'valid'"
else
:
padding
=
"'same'"
return
{
"padding"
:
padding
}
def
gen_explicit_map_one_delta
(
params_pt
,
args_pt
,
k_ms
,
k_pt
):
"""
Generate explicit_map for which include mapping relationship is `1 - k_ms = k_pt`.
Args:
Args:
params_pt (dict): Params for APIPt.
path (str): The file path.
args_pt (dict): Args for APIPt.
Returns:
Returns:
dict,
map between frames
.
dict,
key is op name, value is a relevant instance of MappingHelper
.
"""
"""
value
=
args_pt
[
k_pt
]
if
k_pt
in
args_pt
else
params_pt
[
k_pt
]
mapping_info_d
=
load_json_file
(
path
)
value
=
value
.
strip
()
parse_mapping_dict
=
{}
for
key
,
value
in
mapping_info_d
.
items
():
def
is_number
(
string
):
ms_api_info
=
value
.
pop
(
'ms_api'
)
try
:
ms_api
=
get_ms_api
(
ms_api_info
)
float
(
string
)
pt_api_info
=
value
.
pop
(
'pt_api'
)
return
True
pt_api
=
get_pt_api
(
pt_api_info
)
except
ValueError
:
gen_explicit_map
=
value
.
get
(
'gen_explicit_map'
)
return
False
if
gen_explicit_map
:
module_name
=
import_module
(
FUNC_MODULE
)
if
is_number
(
value
):
value
[
'gen_explicit_map'
]
=
getattr
(
module_name
,
gen_explicit_map
)
return
{
k_ms
:
str
(
1
-
float
(
value
))}
return
{
k_ms
:
"1.0 - "
+
value
}
parse_mapping_dict
.
update
({
key
:
MappingHelper
(
**
dict
(
ms_api
=
ms_api
,
pt_api
=
pt_api
),
**
value
)})
return
parse_mapping_dict
def
load_json_file
(
file_path
):
def
load_json_file
(
file_path
):
...
@@ -350,244 +328,38 @@ def load_json_file(file_path):
...
@@ -350,244 +328,38 @@ def load_json_file(file_path):
# ---------------------------- mappings ----------------------------
# ---------------------------- mappings ----------------------------
NN_MAPPING
=
{
NN_MAPPING_PATH
=
os
.
path
.
realpath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'mappings/nn_mappings.json'
))
'nn.Sequential'
:
MappingHelper
(
**
{
"ms_api"
:
APIMs
(
'nn.SequentialCell'
,
OrderedDict
([(
'*args'
,
REQUIRED
)])),
NN_MAPPING
=
get_mapping_from_file
(
NN_MAPPING_PATH
)
"pt_api"
:
APIPt
(
'nn.Sequential'
,
OrderedDict
([(
'*args'
,
REQUIRED
)])),
# update to add key with full api_name, which starts with 'torch.nn.'
"gen_explicit_map"
:
gen_explicit_map_nn_sequential
,
"export_key"
:
False
}),
'nn.Conv2d'
:
MappingHelper
(
**
{
"ms_api"
:
APIMs
(
'nn.Conv2d'
,
OrderedDict
(
in_channels
=
REQUIRED
,
out_channels
=
REQUIRED
,
kernel_size
=
REQUIRED
,
stride
=
1
,
pad_mode
=
'same'
,
padding
=
0
,
dilation
=
1
,
group
=
1
,
has_bias
=
False
,
weight_init
=
'normal'
,
bias_init
=
'zeros'
)),
"pt_api"
:
APIPt
(
'nn.Conv2d'
,
OrderedDict
(
in_channels
=
REQUIRED
,
out_channels
=
REQUIRED
,
kernel_size
=
REQUIRED
,
stride
=
1
,
padding
=
0
,
dilation
=
1
,
groups
=
1
,
bias
=
True
,
padding_mode
=
'zeros'
)),
"ms2pt_mapping"
:
{
'in_channels'
:
'in_channels'
,
'out_channels'
:
'out_channels'
,
'kernel_size'
:
'kernel_size'
,
'stride'
:
'stride'
,
'padding'
:
'padding'
,
'dilation'
:
'dilation'
,
'group'
:
'groups'
,
'has_bias'
:
'bias'
},
"gen_explicit_map"
:
(
lambda
params_pt
,
args_pt
:
{
"pad_mode"
:
"'pad'"
})
}),
'nn.BatchNorm2d'
:
MappingHelper
(
**
{
"ms_api"
:
APIMs
(
'nn.BatchNorm2d'
,
OrderedDict
(
num_features
=
REQUIRED
,
eps
=
1e-5
,
momentum
=
0.9
,
affine
=
True
,
gamma_init
=
'ones'
,
beta_init
=
'zeros'
,
moving_mean_init
=
'zeros'
,
moving_var_init
=
'ones'
,
use_batch_statistics
=
True
)),
"pt_api"
:
APIPt
(
'nn.BatchNorm2d'
,
OrderedDict
(
num_features
=
REQUIRED
,
eps
=
1e-5
,
momentum
=
0.1
,
affine
=
True
,
track_running_stats
=
True
)),
"ms2pt_mapping"
:
{
"num_features"
:
"num_features"
,
"eps"
:
"eps"
,
"affine"
:
"affine"
,
"use_batch_statistics"
:
"track_running_stats"
},
"gen_explicit_map"
:
partial
(
gen_explicit_map_one_delta
,
k_ms
=
"momentum"
,
k_pt
=
"momentum"
)
}),
'nn.ReLU'
:
MappingHelper
(
**
{
"ms_api"
:
APIMs
(
'nn.ReLU'
,
OrderedDict
()),
"pt_api"
:
APIPt
(
'nn.ReLU'
,
OrderedDict
(
inplace
=
False
)),
"ms2pt_mapping"
:
{}}),
'nn.ReLU6'
:
MappingHelper
(
**
{
"ms_api"
:
APIMs
(
'nn.ReLU6'
,
OrderedDict
()),
"pt_api"
:
APIPt
(
'nn.ReLU6'
,
OrderedDict
(
inplace
=
False
)),
"ms2pt_mapping"
:
{}}),
'nn.Linear'
:
MappingHelper
(
**
{
"ms_api"
:
APIMs
(
'nn.Dense'
,
OrderedDict
(
in_channels
=
REQUIRED
,
out_channels
=
REQUIRED
,
weight_init
=
'normal'
,
bias_init
=
'zeros'
,
has_bias
=
True
,
activation
=
None
)),
"pt_api"
:
APIPt
(
'nn.Linear'
,
OrderedDict
(
in_features
=
REQUIRED
,
out_features
=
REQUIRED
,
bias
=
True
)),
"ms2pt_mapping"
:
{
"in_channels"
:
"in_features"
,
"out_channels"
:
"out_features"
,
"has_bias"
:
"bias"
}
}),
'nn.MaxPool2d'
:
MappingHelper
(
**
{
"ms_api"
:
APIMs
(
'nn.MaxPool2d'
,
OrderedDict
(
kernel_size
=
1
,
stride
=
1
,
pad_mode
=
"valid"
)),
"pt_api"
:
APIPt
(
'nn.MaxPool2d'
,
OrderedDict
(
kernel_size
=
REQUIRED
,
stride
=
None
,
padding
=
0
,
dilation
=
1
,
return_indices
=
False
,
ceil_mode
=
"False"
)),
"ms2pt_mapping"
:
{
"kernel_size"
:
"kernel_size"
,
"stride"
:
"stride"
},
"gen_explicit_map"
:
gen_explicit_map_nn_maxpool2d
}),
'nn.AvgPool2d'
:
MappingHelper
(
**
{
"ms_api"
:
APIMs
(
'nn.AvgPool2d'
,
OrderedDict
(
kernel_size
=
1
,
stride
=
1
,
pad_mode
=
"valid"
)),
"pt_api"
:
APIPt
(
'nn.AvgPool2d'
,
OrderedDict
(
kernel_size
=
REQUIRED
,
stride
=
None
,
padding
=
0
,
dilation
=
1
,
return_indices
=
False
,
ceil_mode
=
"False"
)),
"ms2pt_mapping"
:
{
"kernel_size"
:
"kernel_size"
,
"stride"
:
"stride"
},
"gen_explicit_map"
:
gen_explicit_map_nn_maxpool2d
}),
'nn.Dropout'
:
MappingHelper
(
**
{
"ms_api"
:
APIMs
(
'nn.Dropout'
,
OrderedDict
(
keep_prob
=
0.5
,
seed0
=
0
,
seed1
=
0
,
dtype
=
"mstype.float32"
)),
"pt_api"
:
APIPt
(
'nn.Dropout'
,
OrderedDict
(
p
=
0.5
,
inplace
=
False
)),
"ms2pt_mapping"
:
{
"keep_prob"
:
"p"
},
"gen_explicit_map"
:
partial
(
gen_explicit_map_one_delta
,
k_ms
=
"keep_prob"
,
k_pt
=
"p"
)
})
}
NN_MAPPING
.
update
({
"torch."
+
k
:
v
for
k
,
v
in
NN_MAPPING
.
items
()})
NN_MAPPING
.
update
({
"torch."
+
k
:
v
for
k
,
v
in
NN_MAPPING
.
items
()})
F_MAPPING_PATH
=
os
.
path
.
realpath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'mappings/f_mappings.json'
))
F_MAPPING
=
get_mapping_from_file
(
F_MAPPING_PATH
)
# update to add key starts with 'nn.functional.'
NN_FUNCTIONAL_D
=
{
"nn.functional."
+
k
[
len
(
'F.'
):]:
v
for
k
,
v
in
F_MAPPING
.
items
()}
# update to add key starts with 'torch.nn.functiona.l'
TORCH_NN_FUNCTIONAL_D
=
{
"torch.nn.functional."
+
k
[
len
(
'F.'
):]:
v
for
k
,
v
in
F_MAPPING
.
items
()}
F_MAPPING
.
update
(
NN_FUNCTIONAL_D
)
F_MAPPING
.
update
(
TORCH_NN_FUNCTIONAL_D
)
F_MAPPING
=
{
TORCH_DOT_MAPPING_PATH
=
os
.
path
.
realpath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'mappings/torch_dot_mappings.json'
))
'F.relu'
:
MappingHelper
(
**
{
"ms_api"
:
APIMs
(
'P.ReLU'
,
OrderedDict
(
input
=
REQUIRED
)),
TORCH_DOT_MAPPING
=
get_mapping_from_file
(
TORCH_DOT_MAPPING_PATH
)
"pt_api"
:
APIPt
(
'F.relu'
,
OrderedDict
(
input
=
REQUIRED
,
inplace
=
False
)),
"ms2pt_mapping"
:
{
"input"
:
"input"
},
}),
'F.relu6'
:
MappingHelper
(
**
{
"ms_api"
:
APIMs
(
'P.ReLU6'
,
OrderedDict
(
input
=
REQUIRED
)),
"pt_api"
:
APIPt
(
'F.relu6'
,
OrderedDict
(
input
=
REQUIRED
,
inplace
=
False
)),
"ms2pt_mapping"
:
{
"input"
:
"input"
},
}),
'F.max_pool2d'
:
MappingHelper
(
**
{
"ms_api"
:
APIMs
(
'P.MaxPool'
,
OrderedDict
(
ksize
=
1
,
strides
=
1
,
padding
=
"valid"
,
input
=
REQUIRED
),
p_attrs
=
{
"ksize"
,
"strides"
,
"padding"
}),
"pt_api"
:
APIPt
(
'F.max_pool2d'
,
OrderedDict
(
input
=
REQUIRED
,
kernel_size
=
REQUIRED
,
stride
=
None
,
padding
=
0
,
dilation
=
1
,
ceil_mode
=
False
,
return_indices
=
False
)),
"ms2pt_mapping"
:
{
"ksize"
:
"kernel_size"
,
"strides"
:
"stride"
,
"input"
:
"input"
,
},
"gen_explicit_map"
:
gen_explicit_map_f_max_pool2d
}),
'F.avg_pool2d'
:
MappingHelper
(
**
{
"ms_api"
:
APIMs
(
'P.AvgPool'
,
OrderedDict
(
ksize
=
1
,
strides
=
1
,
padding
=
"valid"
,
input
=
REQUIRED
),
p_attrs
=
{
"ksize"
,
"strides"
,
"padding"
}),
"pt_api"
:
APIPt
(
'F.avg_pool2d'
,
OrderedDict
(
input
=
REQUIRED
,
kernel_size
=
REQUIRED
,
stride
=
None
,
padding
=
0
,
dilation
=
1
,
ceil_mode
=
False
,
return_indices
=
False
)),
"ms2pt_mapping"
:
{
"ksize"
:
"kernel_size"
,
"strides"
:
"stride"
,
"input"
:
"input"
,
},
"gen_explicit_map"
:
gen_explicit_map_f_max_pool2d
}),
}
nn_functional_d
=
{
"nn.functional."
+
k
[
2
:]:
v
for
k
,
v
in
F_MAPPING
.
items
()}
torch_nn_functional_d
=
{
"torch.nn.functional."
+
k
[
2
:]:
v
for
k
,
v
in
F_MAPPING
.
items
()}
F_MAPPING
.
update
(
nn_functional_d
)
F_MAPPING
.
update
(
torch_nn_functional_d
)
TORCH_DOT_MAPPING
=
{
'torch.flatten'
:
MappingHelper
(
**
{
"ms_api"
:
APIMs
(
'P.Flatten'
,
OrderedDict
(
input
=
REQUIRED
)),
"pt_api"
:
APIPt
(
'torch.flatten'
,
OrderedDict
(
input
=
REQUIRED
,
start_dim
=
0
,
end_dim
=-
1
)),
"ms2pt_mapping"
:
{
"input"
:
"input"
}
}),
'torch.cat'
:
MappingHelper
(
**
{
"ms_api"
:
APIMs
(
'P.Concat'
,
OrderedDict
(
axis
=
0
,
input
=
REQUIRED
),
p_attrs
=
{
"axis"
}),
"pt_api"
:
APIPt
(
'torch.flatten'
,
OrderedDict
(
tensors
=
REQUIRED
,
dim
=
0
,
out
=
None
)),
"ms2pt_mapping"
:
{
"input"
:
"tensors"
,
"axis"
:
"dim"
}
}),
}
TENSOR_DOT_MAPPING
=
{
'.view'
:
MappingHelper
(
**
{
"ms_api"
:
APIMs
(
'P.Reshape'
,
OrderedDict
(
x
=
REQUIRED
,
shape
=
REQUIRED
)),
"pt_api"
:
APIPt
(
'.view'
,
OrderedDict
([(
'*shape'
,
REQUIRED
)])),
"ms2pt_mapping"
:
{
"x"
:
"call_name"
},
"gen_explicit_map"
:
(
lambda
params_pt
,
args_pt
:
{
"shape"
:
"("
+
args_pt
[
"*shape"
]
+
",)"
})
}),
'.size'
:
MappingHelper
(
**
{
"ms_api"
:
APIMs
(
'P.Shape'
,
OrderedDict
(
x
=
REQUIRED
)),
"pt_api"
:
APIPt
(
'.size'
,
OrderedDict
([(
'idx'
,
REQUIRED
)])),
"ms2pt_mapping"
:
{
"x"
:
"call_name"
}
}),
'.flatten'
:
MappingHelper
(
**
{
"ms_api"
:
APIMs
(
'P.Flatten'
,
OrderedDict
(
input
=
REQUIRED
)),
"pt_api"
:
APIPt
(
'.flatten'
,
OrderedDict
(
start_dim
=
0
,
end_dim
=-
1
)),
"ms2pt_mapping"
:
{
"input"
:
"call_name"
}
}),
'.reshape'
:
MappingHelper
(
**
{
"ms_api"
:
APIMs
(
'P.Reshape'
,
OrderedDict
(
x
=
REQUIRED
,
shape
=
REQUIRED
)),
"pt_api"
:
APIPt
(
'.reshape'
,
OrderedDict
([(
'*shape'
,
REQUIRED
)])),
"ms2pt_mapping"
:
{
"x"
:
"call_name"
},
"gen_explicit_map"
:
(
lambda
params_pt
,
args_pt
:
{
"shape"
:
"("
+
args_pt
[
"*shape"
]
+
",)"
})
}),
'.mean'
:
MappingHelper
(
**
{
"ms_api"
:
APIMs
(
'P.ReduceMean'
,
OrderedDict
(
keep_dims
=
False
,
input
=
REQUIRED
,
axis
=
()),
p_attrs
=
{
"keep_dims"
}),
"pt_api"
:
APIPt
(
'.mean'
,
OrderedDict
(
dim
=
None
,
keepdim
=
False
)),
"ms2pt_mapping"
:
{
"keep_dims"
:
"keepdim"
,
"axis"
:
"dim"
,
"input"
:
"call_name"
},
}),
'.squeeze'
:
MappingHelper
(
**
{
"ms_api"
:
APIMs
(
'P.ReduceMean'
,
OrderedDict
(
input
=
REQUIRED
,
axis
=
()),
p_attrs
=
{
"axis"
}),
"pt_api"
:
APIPt
(
'.squeeze'
,
OrderedDict
(
dim
=
None
)),
"ms2pt_mapping"
:
{
"axis"
:
"dim"
,
"input"
:
"call_name"
},
}),
}
TENSOR_DOT_MAPPING_PATH
=
os
.
path
.
realpath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'mappings/tensor_dot_mappings.json'
))
TENSOR_DOT_MAPPING
=
get_mapping_from_file
(
TENSOR_DOT_MAPPING_PATH
)
ALL_MAPPING
=
{
**
NN_MAPPING
,
**
F_MAPPING
,
**
TORCH_DOT_MAPPING
,
**
TENSOR_DOT_MAPPING
}
ALL_MAPPING
=
{
**
NN_MAPPING
,
**
F_MAPPING
,
**
TORCH_DOT_MAPPING
,
**
TENSOR_DOT_MAPPING
}
# ---------------------------- api list support or not support ----------------------------
# ---------------------------- api list support or not support ----------------------------
NN_LIST_PATH
=
os
.
path
.
realpath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'nn_list.json'
))
NN_LIST_PATH
=
os
.
path
.
realpath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'
ops'
,
'
nn_list.json'
))
NN_LIST
=
load_json_file
(
NN_LIST_PATH
)
NN_LIST
=
load_json_file
(
NN_LIST_PATH
)
NN_LIST
+=
[
"torch."
+
name
for
name
in
NN_LIST
]
NN_LIST
+=
[
"torch."
+
name
for
name
in
NN_LIST
]
NN_SUPPORTED
=
[
x
for
x
in
NN_LIST
if
x
in
ALL_MAPPING
]
NN_SUPPORTED
=
[
x
for
x
in
NN_LIST
if
x
in
ALL_MAPPING
]
NN_UNSUPPORTED
=
[
x
for
x
in
NN_LIST
if
x
not
in
ALL_MAPPING
]
NN_UNSUPPORTED
=
[
x
for
x
in
NN_LIST
if
x
not
in
ALL_MAPPING
]
F_LIST_PATH
=
os
.
path
.
realpath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'f_list.json'
))
F_LIST_PATH
=
os
.
path
.
realpath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'
ops'
,
'
f_list.json'
))
F_LIST
=
load_json_file
(
F_LIST_PATH
)
F_LIST
=
load_json_file
(
F_LIST_PATH
)
F_LIST
+=
[
"F."
+
name
[
len
(
"torch.nn.functional."
):]
for
name
in
F_LIST
]
+
\
F_LIST
+=
[
"F."
+
name
[
len
(
"torch.nn.functional."
):]
for
name
in
F_LIST
]
+
\
[
name
[
len
(
"torch."
):]
for
name
in
F_LIST
]
[
name
[
len
(
"torch."
):]
for
name
in
F_LIST
]
...
@@ -595,7 +367,7 @@ F_SUPPORTED = [x for x in F_LIST if x in ALL_MAPPING]
...
@@ -595,7 +367,7 @@ F_SUPPORTED = [x for x in F_LIST if x in ALL_MAPPING]
F_UNSUPPORTED
=
[
x
for
x
in
F_LIST
if
x
not
in
ALL_MAPPING
]
F_UNSUPPORTED
=
[
x
for
x
in
F_LIST
if
x
not
in
ALL_MAPPING
]
TORCH_DOT_LIST_PATH
=
os
.
path
.
realpath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'torch_dot_list.json'
))
TORCH_DOT_LIST_PATH
=
os
.
path
.
realpath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'
ops'
,
'
torch_dot_list.json'
))
TORCH_DOT_LIST
=
load_json_file
(
TORCH_DOT_LIST_PATH
)
TORCH_DOT_LIST
=
load_json_file
(
TORCH_DOT_LIST_PATH
)
...
@@ -603,7 +375,7 @@ TORCH_DOT_SUPPORTED = [x for x in TORCH_DOT_LIST if x in ALL_MAPPING]
...
@@ -603,7 +375,7 @@ TORCH_DOT_SUPPORTED = [x for x in TORCH_DOT_LIST if x in ALL_MAPPING]
TORCH_DOT_UNSUPPORTED
=
[
x
for
x
in
TORCH_DOT_LIST
if
x
not
in
ALL_MAPPING
]
TORCH_DOT_UNSUPPORTED
=
[
x
for
x
in
TORCH_DOT_LIST
if
x
not
in
ALL_MAPPING
]
TENSOR_DOT_LIST_PATH
=
os
.
path
.
realpath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'tensor_dot_list.json'
))
TENSOR_DOT_LIST_PATH
=
os
.
path
.
realpath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'
ops'
,
'
tensor_dot_list.json'
))
TENSOR_DOT_LIST
=
load_json_file
(
TENSOR_DOT_LIST_PATH
)
TENSOR_DOT_LIST
=
load_json_file
(
TENSOR_DOT_LIST_PATH
)
...
@@ -620,5 +392,5 @@ ALL_UNSUPPORTED = NN_UNSUPPORTED + F_UNSUPPORTED + TORCH_DOT_UNSUPPORTED + TENSO
...
@@ -620,5 +392,5 @@ ALL_UNSUPPORTED = NN_UNSUPPORTED + F_UNSUPPORTED + TORCH_DOT_UNSUPPORTED + TENSO
UNSUPPORTED_WARN_INFOS
=
{
UNSUPPORTED_WARN_INFOS
=
{
"nn.AdaptiveAvgPool2d"
:
"maybe could convert to P.ReduceMean"
,
"nn.AdaptiveAvgPool2d"
:
"maybe could convert to P.ReduceMean"
,
"F.adaptive_avg_pool2d"
:
"maybe could convert to P.ReduceMean"
,
"F.adaptive_avg_pool2d"
:
"maybe could convert to P.ReduceMean"
,
"F.dropout"
:
"please use nn.Dropout in __init__()"
,
"F.dropout"
:
"please use nn.Dropout in __init__()"
}
}
mindinsight/mindconverter/enums.py
已删除
100644 → 0
浏览文件 @
3abcb4a2
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Enums."""
from
enum
import
Enum
class
RequriedType
(
Enum
):
"""If param is required"""
REQUIRED
=
1
UNREQUIRED
=
2
mindinsight/mindconverter/funcs.py
0 → 100644
浏览文件 @
2f2f6a3f
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless REQUIRED by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""funcs for gen_explicit_map"""
from
functools
import
partial
def
gen_explicit_map_f_max_pool2d
(
params_pt
,
args_pt
):
"""
Generate explicit_map for F.MaxPool2d.
Args:
params_pt (dict): Params for APIPt.
args_pt (dict): Args for APIPt.
Returns:
dict, map between frames.
"""
if
'padding'
in
args_pt
:
padding
=
args_pt
[
'padding'
]
else
:
padding
=
params_pt
[
'padding'
]
if
padding
.
strip
()
in
(
"0"
,
"(0,0)"
,
"(0, 0)"
):
padding
=
"'valid'"
else
:
padding
=
"'same'"
return
{
"padding"
:
padding
}
def
gen_explicit_map_nn_sequential
(
_
,
args_pt
):
"""
Generate explicit_map for nn.Sequential.
Args:
args_pt (dict): Args for APIPt.
Returns:
dict, map between frames.
"""
args
=
args_pt
[
'*args'
]
return
{
"*args"
:
"[{}]"
.
format
(
args
)}
def
gen_explicit_map_one_delta
(
params_pt
,
args_pt
,
k_ms
,
k_pt
):
"""
Generate explicit_map for which include mapping relationship is `1 - k_ms = k_pt`.
Args:
params_pt (dict): Params for APIPt.
args_pt (dict): Args for APIPt.
Returns:
dict, map between frames.
"""
value
=
args_pt
[
k_pt
]
if
k_pt
in
args_pt
else
params_pt
[
k_pt
]
value
=
value
.
strip
()
def
is_number
(
string
):
try
:
float
(
string
)
return
True
except
ValueError
:
return
False
if
is_number
(
value
):
return
{
k_ms
:
str
(
1
-
float
(
value
))}
return
{
k_ms
:
"1.0 - "
+
value
}
def
gen_explicit_map_nn_maxpool2d
(
params_pt
,
args_pt
):
"""
Generate explicit_map for nn.MaxPool2d.
Args:
params_pt (dict): Params for APIPt.
args_pt (dict): Args for APIPt.
Returns:
dict, map between frames.
"""
if
'padding'
in
args_pt
:
padding
=
args_pt
[
'padding'
]
else
:
padding
=
params_pt
[
'padding'
]
if
padding
.
strip
()
in
(
"0"
,
"(0,0)"
,
"(0, 0)"
):
pad_mode
=
"'valid'"
else
:
pad_mode
=
"'same'"
return
{
"pad_mode"
:
pad_mode
}
tensor_dot_view_gen_explicit_map
=
lambda
params_pt
,
args_pt
:
{
"shape"
:
"("
+
args_pt
[
"*shape"
]
+
",)"
}
tensor_dot_reshape_gen_explicit_map
=
lambda
params_pt
,
args_pt
:
{
"shape"
:
"("
+
args_pt
[
"*shape"
]
+
",)"
}
nn_conv2d_gen_explicit_map
=
lambda
params_pt
,
args_pt
:
{
"pad_mode"
:
"'pad'"
}
nn_batchnorm2d_gen_explicit_map
=
partial
(
gen_explicit_map_one_delta
,
k_ms
=
"momentum"
,
k_pt
=
"momentum"
)
nn_dropout_gen_explicit_map
=
partial
(
gen_explicit_map_one_delta
,
k_ms
=
"keep_prob"
,
k_pt
=
"p"
)
mindinsight/mindconverter/mappings/f_mappings.json
0 → 100644
浏览文件 @
2f2f6a3f
{
"F.avg_pool2d"
:
{
"ms_api"
:
[
"P.AvgPool"
,
{
"ksize"
:
1
,
"strides"
:
1
,
"padding"
:
"valid"
,
"input"
:
"REQUIRED"
},
[
"ksize"
,
"strides"
,
"padding"
]
],
"pt_api"
:
[
"F.avg_pool2d"
,
{
"input"
:
"REQUIRED"
,
"kernel_size"
:
"REQUIRED"
,
"stride"
:
null
,
"padding"
:
0
,
"dilation"
:
1
,
"ceil_mode"
:
false
,
"return_indices"
:
false
}
],
"ms2pt_mapping"
:
{
"ksize"
:
"kernel_size"
,
"strides"
:
"stride"
,
"input"
:
"input"
},
"gen_explicit_map"
:
"gen_explicit_map_f_max_pool2d"
},
"F.max_pool2d"
:
{
"ms_api"
:
[
"P.MaxPool"
,
{
"ksize"
:
1
,
"strides"
:
1
,
"padding"
:
"valid"
,
"input"
:
"REQUIRED"
},
[
"ksize"
,
"strides"
,
"padding"
]
],
"pt_api"
:
[
"F.max_pool2d"
,
{
"input"
:
"REQUIRED"
,
"kernel_size"
:
"REQUIRED"
,
"stride"
:
null
,
"padding"
:
0
,
"dilation"
:
1
,
"ceil_mode"
:
false
,
"return_indices"
:
false
}
],
"ms2pt_mapping"
:
{
"ksize"
:
"kernel_size"
,
"strides"
:
"stride"
,
"input"
:
"input"
},
"gen_explicit_map"
:
"gen_explicit_map_f_max_pool2d"
},
"F.relu"
:
{
"ms_api"
:
[
"P.ReLU"
,
{
"input"
:
"REQUIRED"
}
],
"pt_api"
:
[
"F.relu"
,
{
"input"
:
"REQUIRED"
,
"inplace"
:
false
}
],
"ms2pt_mapping"
:
{
"input"
:
"input"
},
"gen_explicit_map"
:
null
},
"F.relu6"
:
{
"ms_api"
:
[
"P.ReLU6"
,
{
"input"
:
"REQUIRED"
}
],
"pt_api"
:
[
"F.relu6"
,
{
"input"
:
"REQUIRED"
,
"inplace"
:
false
}
],
"ms2pt_mapping"
:
{
"input"
:
"input"
},
"gen_explicit_map"
:
null
}
}
\ No newline at end of file
mindinsight/mindconverter/mappings/nn_mappings.json
0 → 100644
浏览文件 @
2f2f6a3f
{
"nn.Dropout"
:
{
"ms_api"
:
[
"nn.Dropout"
,
{
"keep_prob"
:
0.5
,
"seed0"
:
0
,
"seed1"
:
0
,
"dtype"
:
"mstype.float32"
}
],
"pt_api"
:
[
"nn.Dropout"
,
{
"p"
:
0.5
,
"inplace"
:
false
}
],
"ms2pt_mapping"
:
{
"keep_prob"
:
"p"
},
"gen_explicit_map"
:
"nn_dropout_gen_explicit_map"
},
"nn.AvgPool2d"
:
{
"ms_api"
:
[
"nn.AvgPool2d"
,
{
"kernel_size"
:
1
,
"stride"
:
1
,
"pad_mode"
:
"valid"
}
],
"pt_api"
:
[
"nn.AvgPool2d"
,
{
"kernel_size"
:
"REQUIRED"
,
"stride"
:
null
,
"padding"
:
0
,
"dilation"
:
1
,
"return_indices"
:
false
,
"ceil_mode"
:
"False"
}
],
"ms2pt_mapping"
:
{
"kernel_size"
:
"kernel_size"
,
"stride"
:
"stride"
},
"gen_explicit_map"
:
"gen_explicit_map_nn_maxpool2d"
},
"nn.MaxPool2d"
:
{
"ms_api"
:
[
"nn.MaxPool2d"
,
{
"kernel_size"
:
1
,
"stride"
:
1
,
"pad_mode"
:
"valid"
}
],
"pt_api"
:
[
"nn.MaxPool2d"
,
{
"kernel_size"
:
"REQUIRED"
,
"stride"
:
null
,
"padding"
:
0
,
"dilation"
:
1
,
"return_indices"
:
false
,
"ceil_mode"
:
"False"
}
],
"ms2pt_mapping"
:
{
"kernel_size"
:
"kernel_size"
,
"stride"
:
"stride"
},
"gen_explicit_map"
:
"gen_explicit_map_nn_maxpool2d"
},
"nn.Linear"
:
{
"ms_api"
:
[
"nn.Dense"
,
{
"in_channels"
:
"REQUIRED"
,
"out_channels"
:
"REQUIRED"
,
"weight_init"
:
"normal"
,
"bias_init"
:
"zeros"
,
"has_bias"
:
true
,
"activation"
:
null
}
],
"pt_api"
:
[
"nn.Linear"
,
{
"in_features"
:
"REQUIRED"
,
"out_features"
:
"REQUIRED"
,
"bias"
:
true
}
],
"ms2pt_mapping"
:
{
"in_channels"
:
"in_features"
,
"out_channels"
:
"out_features"
,
"has_bias"
:
"bias"
}
},
"nn.ReLU6"
:
{
"ms_api"
:
[
"nn.ReLU6"
,
{}
],
"pt_api"
:
[
"nn.ReLU6"
,
{
"inplace"
:
false
}
],
"ms2pt_mapping"
:
{}
},
"nn.ReLU"
:
{
"ms_api"
:
[
"nn.ReLU"
,
{}
],
"pt_api"
:
[
"F.relu"
,
{
"inplace"
:
false
}
],
"ms2pt_mapping"
:
{}
},
"nn.BatchNorm2d"
:
{
"ms_api"
:
[
"nn.BatchNorm2d"
,
{
"num_features"
:
"REQUIRED"
,
"eps"
:
1e-05
,
"momentum"
:
0.9
,
"affine"
:
true
,
"gamma_init"
:
"ones"
,
"beta_init"
:
"zeros"
,
"moving_mean_init"
:
"zeros"
,
"moving_var_init"
:
"ones"
,
"use_batch_statistics"
:
true
}
],
"pt_api"
:
[
"nn.BatchNorm2d"
,
{
"num_features"
:
"REQUIRED"
,
"eps"
:
1e-05
,
"momentum"
:
0.1
,
"affine"
:
true
,
"track_running_stats"
:
true
}
],
"ms2pt_mapping"
:
{
"num_features"
:
"num_features"
,
"eps"
:
"eps"
,
"affine"
:
"affine"
,
"use_batch_statistics"
:
"track_running_stats"
},
"gen_explicit_map"
:
"nn_batchnorm2d_gen_explicit_map"
},
"nn.Conv2d"
:
{
"ms_api"
:
[
"nn.Conv2d"
,
{
"in_channels"
:
"REQUIRED"
,
"out_channels"
:
"REQUIRED"
,
"kernel_size"
:
"REQUIRED"
,
"stride"
:
1
,
"pad_mode"
:
"same"
,
"padding"
:
0
,
"dilation"
:
1
,
"group"
:
1
,
"has_bias"
:
false
,
"weight_init"
:
"normal"
,
"bias_init"
:
"zeros"
}
],
"pt_api"
:
[
"nn.Conv2d"
,
{
"in_channels"
:
"REQUIRED"
,
"out_channels"
:
"REQUIRED"
,
"kernel_size"
:
"REQUIRED"
,
"stride"
:
1
,
"padding"
:
0
,
"dilation"
:
1
,
"groups"
:
1
,
"bias"
:
true
,
"padding_mode"
:
"zeros"
}
],
"ms2pt_mapping"
:
{
"in_channels"
:
"in_channels"
,
"out_channels"
:
"out_channels"
,
"kernel_size"
:
"kernel_size"
,
"stride"
:
"stride"
,
"padding"
:
"padding"
,
"dilation"
:
"dilation"
,
"group"
:
"groups"
,
"has_bias"
:
"bias"
},
"gen_explicit_map"
:
"nn_conv2d_gen_explicit_map"
},
"nn.Sequential"
:
{
"ms_api"
:
[
"nn.SequentialCell"
,
{
"*args"
:
" REQUIRED"
}
],
"pt_api"
:
[
"nn.Sequential"
,
{
"*args"
:
" REQUIRED"
}
],
"export_key"
:
false
,
"gen_explicit_map"
:
"gen_explicit_map_nn_sequential"
}
}
\ No newline at end of file
mindinsight/mindconverter/mappings/tensor_dot_mappings.json
0 → 100644
浏览文件 @
2f2f6a3f
{
".view"
:
{
"ms_api"
:
[
"P.Reshape"
,
{
"x"
:
"REQUIRED"
,
"shape"
:
"REQUIRED"
}
],
"pt_api"
:
[
".view"
,
{
"*shape"
:
"REQUIRED"
}
],
"ms2pt_mapping"
:
{
"x"
:
"call_name"
},
"gen_explicit_map"
:
"tensor_dot_view_gen_explicit_map"
},
".size"
:
{
"ms_api"
:
[
"P.Shape"
,
{
"x"
:
"REQUIRED"
}
],
"pt_api"
:
[
".size"
,
{
"idx"
:
"REQUIRED"
}
],
"ms2pt_mapping"
:
{
"x"
:
"call_name"
}
},
".flatten"
:
{
"ms_api"
:
[
"P.Flatten"
,
{
"input"
:
"REQUIRED"
}
],
"pt_api"
:
[
".flatten"
,
{
"start_dim"
:
0
,
"end_dim"
:
-1
}
],
"ms2pt_mapping"
:
{
"input"
:
"call_name"
}
},
".reshape"
:
{
"ms_api"
:
[
"P.Reshape"
,
{
"x"
:
"REQUIRED"
,
"shape"
:
"REQUIRED"
}
],
"pt_api"
:
[
".reshape"
,
{
"*shape"
:
"REQUIRED"
}
],
"ms2pt_mapping"
:
{
"x"
:
"call_name"
},
"gen_explicit_map"
:
"tensor_dot_reshape_gen_explicit_map"
},
".mean"
:
{
"ms_api"
:
[
"P.ReduceMean"
,
{
"keep_dims"
:
false
,
"input"
:
"REQUIRED"
,
"axis"
:
[]
}
],
"pt_api"
:
[
".mean"
,
{
"dim"
:
null
,
"keepdim"
:
false
}
],
"ms2pt_mapping"
:
{
"keep_dims"
:
"keepdim"
,
"axis"
:
"dim"
,
"input"
:
"call_name"
}
},
".squeeze"
:
{
"ms_api"
:
[
"P.ReduceMean"
,
{
"input"
:
"REQUIRED"
,
"axis"
:
[]
},
[
"axis"
]
],
"pt_api"
:
[
".squeeze"
,
{
"dim"
:
null
}
],
"ms2pt_mapping"
:
{
"axis"
:
"dim"
,
"input"
:
"call_name"
}
}
}
\ No newline at end of file
mindinsight/mindconverter/mappings/torch_dot_mappings.json
0 → 100644
浏览文件 @
2f2f6a3f
{
"torch.flatten"
:
{
"ms_api"
:
[
"P.Flatten"
,
{
"input"
:
"REQUIRED"
}
],
"pt_api"
:
[
"torch.flatten"
,
{
"input"
:
"REQUIRED"
,
"start_dim"
:
0
,
"end_dim"
:
-1
}
],
"ms2pt_mapping"
:
{
"input"
:
"input"
}
},
"torch.cat"
:
{
"ms_api"
:
[
"P.Concat"
,
{
"axis"
:
0
,
"input"
:
"REQUIRED"
},
[
"axis"
]
],
"pt_api"
:
[
"torch.cat"
,
{
"tensors"
:
"REQUIRED"
,
"dim"
:
0
,
"out"
:
null
}
],
"ms2pt_mapping"
:
{
"input"
:
"tensors"
,
"axis"
:
"dim"
}
}
}
\ No newline at end of file
mindinsight/mindconverter/f_list.json
→
mindinsight/mindconverter/
ops/
f_list.json
浏览文件 @
2f2f6a3f
文件已移动
mindinsight/mindconverter/nn_list.json
→
mindinsight/mindconverter/
ops/
nn_list.json
浏览文件 @
2f2f6a3f
文件已移动
mindinsight/mindconverter/tensor_dot_list.json
→
mindinsight/mindconverter/
ops/
tensor_dot_list.json
浏览文件 @
2f2f6a3f
文件已移动
mindinsight/mindconverter/torch_dot_list.json
→
mindinsight/mindconverter/
ops/
torch_dot_list.json
浏览文件 @
2f2f6a3f
文件已移动
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录