Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
mindinsight
提交
e569b96b
M
mindinsight
项目概览
MindSpore
/
mindinsight
通知
8
Star
4
Fork
2
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindinsight
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e569b96b
编写于
6月 11, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 11, 2020
浏览文件
操作
浏览文件
下载
差异文件
!249 add to support more ops for mindconverter
Merge pull request !249 from quyongxiu1/br_add_op_qyx
上级
7d4bdc72
8c0104f4
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
742 addition
and
3 deletion
+742
-3
mindinsight/mindconverter/config.py
mindinsight/mindconverter/config.py
+4
-2
mindinsight/mindconverter/funcs.py
mindinsight/mindconverter/funcs.py
+24
-1
mindinsight/mindconverter/mappings/f_mappings.json
mindinsight/mindconverter/mappings/f_mappings.json
+45
-0
mindinsight/mindconverter/mappings/nn_mappings.json
mindinsight/mindconverter/mappings/nn_mappings.json
+113
-0
mindinsight/mindconverter/mappings/tensor_dot_mappings.json
mindinsight/mindconverter/mappings/tensor_dot_mappings.json
+38
-0
mindinsight/mindconverter/mappings/torch_dot_mappings.json
mindinsight/mindconverter/mappings/torch_dot_mappings.json
+201
-0
tests/ut/mindconverter/test_converter.py
tests/ut/mindconverter/test_converter.py
+317
-0
未找到文件。
mindinsight/mindconverter/config.py
浏览文件 @
e569b96b
...
@@ -337,7 +337,7 @@ F_MAPPING_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'mappi
...
@@ -337,7 +337,7 @@ F_MAPPING_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), 'mappi
F_MAPPING
=
get_mapping_from_file
(
F_MAPPING_PATH
)
F_MAPPING
=
get_mapping_from_file
(
F_MAPPING_PATH
)
# update to add key starts with 'nn.functional.'
# update to add key starts with 'nn.functional.'
NN_FUNCTIONAL_D
=
{
"nn.functional."
+
k
[
len
(
'F.'
):]:
v
for
k
,
v
in
F_MAPPING
.
items
()}
NN_FUNCTIONAL_D
=
{
"nn.functional."
+
k
[
len
(
'F.'
):]:
v
for
k
,
v
in
F_MAPPING
.
items
()}
# update to add key starts with 'torch.nn.functiona
.l
'
# update to add key starts with 'torch.nn.functiona
l.
'
TORCH_NN_FUNCTIONAL_D
=
{
"torch.nn.functional."
+
k
[
len
(
'F.'
):]:
v
for
k
,
v
in
F_MAPPING
.
items
()}
TORCH_NN_FUNCTIONAL_D
=
{
"torch.nn.functional."
+
k
[
len
(
'F.'
):]:
v
for
k
,
v
in
F_MAPPING
.
items
()}
F_MAPPING
.
update
(
NN_FUNCTIONAL_D
)
F_MAPPING
.
update
(
NN_FUNCTIONAL_D
)
F_MAPPING
.
update
(
TORCH_NN_FUNCTIONAL_D
)
F_MAPPING
.
update
(
TORCH_NN_FUNCTIONAL_D
)
...
@@ -392,5 +392,7 @@ ALL_UNSUPPORTED = NN_UNSUPPORTED + F_UNSUPPORTED + TORCH_DOT_UNSUPPORTED + TENSO
...
@@ -392,5 +392,7 @@ ALL_UNSUPPORTED = NN_UNSUPPORTED + F_UNSUPPORTED + TORCH_DOT_UNSUPPORTED + TENSO
UNSUPPORTED_WARN_INFOS
=
{
UNSUPPORTED_WARN_INFOS
=
{
"nn.AdaptiveAvgPool2d"
:
"maybe could convert to P.ReduceMean"
,
"nn.AdaptiveAvgPool2d"
:
"maybe could convert to P.ReduceMean"
,
"F.adaptive_avg_pool2d"
:
"maybe could convert to P.ReduceMean"
,
"F.adaptive_avg_pool2d"
:
"maybe could convert to P.ReduceMean"
,
"F.dropout"
:
"please use nn.Dropout in __init__()"
"F.dropout"
:
"please use nn.Dropout in __init__()"
,
"torch.max"
:
"try to use P.ArgMaxWithValue, notice that two values are returned by P.ArgMaxWithValue"
,
"torch.min"
:
"try to use P.ArgMinWithValue, notice that two values are returned by P.ArgMinWithValue"
}
}
mindinsight/mindconverter/funcs.py
浏览文件 @
e569b96b
...
@@ -99,8 +99,31 @@ def gen_explicit_map_nn_maxpool2d(params_pt, args_pt):
...
@@ -99,8 +99,31 @@ def gen_explicit_map_nn_maxpool2d(params_pt, args_pt):
pad_mode
=
"'same'"
pad_mode
=
"'same'"
return
{
"pad_mode"
:
pad_mode
}
return
{
"pad_mode"
:
pad_mode
}
tensor_dot_view_gen_explicit_map
=
lambda
params_pt
,
args_pt
:
{
"shape"
:
"("
+
args_pt
[
"*shape"
]
+
",)"
}
def
torch_dot_eye_gen_explicit_map
(
_
,
args_pt
):
"""
Generate explicit_map for torch.eye.
Args:
args_pt (dict): Args for APIPt.
Returns:
dict, map between frames.
"""
explicit_map
=
{
't'
:
'mindspore.int32'
}
if
args_pt
.
get
(
'm'
):
explicit_map
.
update
({
'm'
:
args_pt
.
get
(
'm'
)})
else
:
explicit_map
.
update
({
'm'
:
args_pt
.
get
(
'n'
)})
return
explicit_map
tensor_dot_permute_gen_explicit_map
=
lambda
params_pt
,
args_pt
:
{
"input_perm"
:
"("
+
args_pt
[
"*dIms"
]
+
",)"
}
tensor_dot_repeat_gen_explicit_map
=
lambda
params_pt
,
args_pt
:
{
"multiples"
:
"("
+
args_pt
[
"*sizes"
]
+
",)"
}
tensor_dot_reshape_gen_explicit_map
=
lambda
params_pt
,
args_pt
:
{
"shape"
:
"("
+
args_pt
[
"*shape"
]
+
",)"
}
tensor_dot_reshape_gen_explicit_map
=
lambda
params_pt
,
args_pt
:
{
"shape"
:
"("
+
args_pt
[
"*shape"
]
+
",)"
}
tensor_dot_view_gen_explicit_map
=
lambda
params_pt
,
args_pt
:
{
"shape"
:
"("
+
args_pt
[
"*shape"
]
+
",)"
}
nn_conv2d_gen_explicit_map
=
lambda
params_pt
,
args_pt
:
{
"pad_mode"
:
"'pad'"
}
nn_conv2d_gen_explicit_map
=
lambda
params_pt
,
args_pt
:
{
"pad_mode"
:
"'pad'"
}
nn_batchnorm2d_gen_explicit_map
=
partial
(
gen_explicit_map_one_delta
,
k_ms
=
"momentum"
,
k_pt
=
"momentum"
)
nn_batchnorm2d_gen_explicit_map
=
partial
(
gen_explicit_map_one_delta
,
k_ms
=
"momentum"
,
k_pt
=
"momentum"
)
nn_batchnorm1d_gen_explicit_map
=
nn_batchnorm2d_gen_explicit_map
nn_dropout_gen_explicit_map
=
partial
(
gen_explicit_map_one_delta
,
k_ms
=
"keep_prob"
,
k_pt
=
"p"
)
nn_dropout_gen_explicit_map
=
partial
(
gen_explicit_map_one_delta
,
k_ms
=
"keep_prob"
,
k_pt
=
"p"
)
torch_dot_add_gen_explicit_map
=
lambda
params_pt
,
args_pt
:
\
{
"input_y"
:
(
args_pt
[
'value'
]
+
'*'
+
args_pt
[
"alpha"
])
if
args_pt
.
get
(
"alpha"
)
else
args_pt
[
'value'
]}
mindinsight/mindconverter/mappings/f_mappings.json
浏览文件 @
e569b96b
...
@@ -104,5 +104,50 @@
...
@@ -104,5 +104,50 @@
"input"
:
"input"
"input"
:
"input"
},
},
"gen_explicit_map"
:
null
"gen_explicit_map"
:
null
},
"F.normalize"
:
{
"ms_api"
:
[
"P.L2Normalize"
,
{
"axis"
:
0
,
"epsilon"
:
0.0001
,
"input_x"
:
"REQUIRED"
},
[
"axis"
,
"epsilon"
]
],
"pt_api"
:
[
"F.normalize"
,
{
"input"
:
"REQUIRED"
,
"p"
:
2
,
"dim"
:
1
,
"eps"
:
1e-12
}
],
"ms2pt_mapping"
:
{
"input_x"
:
"input"
,
"epsilon"
:
"eps"
,
"axis"
:
"dim"
}
},
"F.sigmoid"
:
{
"ms_api"
:
[
"P.Sigmoid"
,
{
"input_x"
:
"REQUIRED"
}
],
"pt_api"
:
[
"F.sigmoid"
,
{
"input"
:
"REQUIRED"
}
],
"ms2pt_mapping"
:
{
"input_x"
:
"input"
}
}
}
}
}
\ No newline at end of file
mindinsight/mindconverter/mappings/nn_mappings.json
浏览文件 @
e569b96b
...
@@ -216,5 +216,118 @@
...
@@ -216,5 +216,118 @@
],
],
"export_key"
:
false
,
"export_key"
:
false
,
"gen_explicit_map"
:
"gen_explicit_map_nn_sequential"
"gen_explicit_map"
:
"gen_explicit_map_nn_sequential"
},
"nn.BatchNorm1d"
:
{
"ms_api"
:
[
"nn.BatchNorm1d"
,
{
"num_features"
:
"REQUIRED"
,
"eps"
:
1e-05
,
"momentum"
:
0.9
,
"affine"
:
true
,
"gamma_init"
:
"ones"
,
"beta_init"
:
"zeros"
,
"moving_mean_init"
:
"zeros"
,
"moving_var_init"
:
"ones"
,
"use_batch_statistics"
:
true
}
],
"pt_api"
:
[
"nn.BatchNorm1d"
,
{
"num_features"
:
"REQUIRED"
,
"eps"
:
1e-05
,
"momentum"
:
0.1
,
"affine"
:
true
,
"track_running_stats"
:
true
}
],
"ms2pt_mapping"
:
{
"num_features"
:
"num_features"
,
"eps"
:
"eps"
,
"affine"
:
"affine"
,
"use_batch_statistics"
:
"track_running_stats"
},
"gen_explicit_map"
:
"nn_batchnorm1d_gen_explicit_map"
},
"nn.LayerNorm"
:
{
"ms_api"
:
[
"nn.LayerNorm"
,
{
"normalized_shape"
:
"REQUIRED"
,
"begin_norm_axis"
:
-1
,
"begin_params_axis"
:
-1
,
"gamma_init"
:
"ones"
,
"beta_init"
:
"zeros"
,
"epsilon"
:
1e-07
}
],
"pt_api"
:
[
"nn.LayerNorm"
,
{
"normalized_shape"
:
"REQUIRED"
,
"eps"
:
1e-05
,
"elementwise_affine"
:
true
}
],
"ms2pt_mapping"
:
{
"normalized_shape"
:
"normalized_shape"
,
"epsilon"
:
"eps"
}
},
"nn.LeakyReLU"
:
{
"ms_api"
:
[
"nn.LeakyReLU"
,
{
"alpha"
:
0.2
}
],
"pt_api"
:
[
"nn.LeakyReLU"
,
{
"negative_slope"
:
0.2
,
"inplace"
:
false
}
],
"ms2pt_mapping"
:
{
"alpha"
:
"negative_slope"
}
},
"nn.PReLU"
:
{
"ms_api"
:
[
"nn.PReLU"
,
{
"channel"
:
1
,
"w"
:
0.25
}
],
"pt_api"
:
[
"nn.PReLU"
,
{
"num_parameters"
:
1
,
"init"
:
0.25
}
],
"ms2pt_mapping"
:
{
"channel"
:
"num_parameters"
,
"w"
:
"init"
}
},
"nn.Softmax"
:
{
"ms_api"
:
[
"nn.Softmax"
,
{
"axis"
:
-1
}
],
"pt_api"
:
[
"nn.Softmax"
,
{
"dim"
:
"REQUIRED"
}
],
"ms2pt_mapping"
:
{
"axis"
:
"dim"
}
}
}
}
}
\ No newline at end of file
mindinsight/mindconverter/mappings/tensor_dot_mappings.json
浏览文件 @
e569b96b
...
@@ -115,5 +115,43 @@
...
@@ -115,5 +115,43 @@
"axis"
:
"dim"
,
"axis"
:
"dim"
,
"input"
:
"call_name"
"input"
:
"call_name"
}
}
},
".repeat"
:
{
"ms_api"
:
[
"P.Tile"
,
{
"input_x"
:
"REQUIRED"
,
"multiples"
:
"REQUIRED"
}
],
"pt_api"
:
[
".repeat"
,
{
"*sizes"
:
"REQUIRED"
}
],
"ms2pt_mapping"
:
{
"input_x"
:
"call_name"
},
"gen_explicit_map"
:
"tensor_dot_repeat_gen_explicit_map"
},
".permute"
:
{
"ms_api"
:
[
"P.Transpose"
,
{
"input_x"
:
"REQUIRED"
,
"input_perm"
:
"REQUIRED"
}
],
"pt_api"
:
[
".permute"
,
{
"*dIms"
:
"REQUIRED"
}
],
"ms2pt_mapping"
:
{
"input_x"
:
"call_name"
},
"gen_explicit_map"
:
"tensor_dot_permute_gen_explicit_map"
}
}
}
}
\ No newline at end of file
mindinsight/mindconverter/mappings/torch_dot_mappings.json
浏览文件 @
e569b96b
...
@@ -41,5 +41,206 @@
...
@@ -41,5 +41,206 @@
"input"
:
"tensors"
,
"input"
:
"tensors"
,
"axis"
:
"dim"
"axis"
:
"dim"
}
}
},
"torch.abs"
:
{
"ms_api"
:
[
"P.Abs"
,
{
"input_x"
:
"REQUIRED"
}
],
"pt_api"
:
[
".abs"
,
{
"input"
:
"REQUIRED"
}
],
"ms2pt_mapping"
:
{
"input_x"
:
"input"
}
},
"torch.acos"
:
{
"ms_api"
:
[
"P.ACos"
,
{
"input_x"
:
"REQUIRED"
}
],
"pt_api"
:
[
".acos"
,
{
"input"
:
"REQUIRED"
}
],
"ms2pt_mapping"
:
{
"input_x"
:
"input"
}
},
"torch.cos"
:
{
"ms_api"
:
[
"P.Cos"
,
{
"input_x"
:
"REQUIRED"
}
],
"pt_api"
:
[
".cos"
,
{
"input"
:
"REQUIRED"
}
],
"ms2pt_mapping"
:
{
"input_x"
:
"input"
}
},
"torch.exp"
:
{
"ms_api"
:
[
"P.Exp"
,
{
"input_x"
:
"REQUIRED"
}
],
"pt_api"
:
[
".exp"
,
{
"input"
:
"REQUIRED"
}
],
"ms2pt_mapping"
:
{
"input_x"
:
"input"
}
},
"torch.log"
:
{
"ms_api"
:
[
"P.Log"
,
{
"input_x"
:
"REQUIRED"
}
],
"pt_api"
:
[
".log"
,
{
"input"
:
"REQUIRED"
}
],
"ms2pt_mapping"
:
{
"input_x"
:
"input"
}
},
"torch.pow"
:
{
"ms_api"
:
[
"P.Pow"
,
{
"input_x"
:
"REQUIRED"
,
"input_y"
:
"REQUIRED"
}
],
"pt_api"
:
[
".pow"
,
{
"input"
:
"REQUIRED"
,
"exponent"
:
"REQUIRED"
}
],
"ms2pt_mapping"
:
{
"input_x"
:
"input"
,
"input_y"
:
"exponent"
}
},
"torch.div"
:
{
"ms_api"
:
[
"P.Div"
,
{
"input_x"
:
"REQUIRED"
,
"input_y"
:
"REQUIRED"
}
],
"pt_api"
:
[
".div"
,
{
"input"
:
"REQUIRED"
,
"other"
:
"REQUIRED"
}
],
"ms2pt_mapping"
:
{
"input_x"
:
"input"
,
"input_y"
:
"other"
}
},
"torch.sin"
:
{
"ms_api"
:
[
"P.Sin"
,
{
"input_x"
:
"REQUIRED"
}
],
"pt_api"
:
[
".sin"
,
{
"input"
:
"REQUIRED"
}
],
"ms2pt_mapping"
:
{
"input_x"
:
"input"
}
},
"torch.sqrt"
:
{
"ms_api"
:
[
"P.Sqrt"
,
{
"input_x"
:
"REQUIRED"
}
],
"pt_api"
:
[
".sqrt"
,
{
"input"
:
"REQUIRED"
}
],
"ms2pt_mapping"
:
{
"input_x"
:
"input"
}
},
"torch.add"
:
{
"ms_api"
:
[
"P.TensorAdd"
,
{
"input_x"
:
"REQUIRED"
,
"input_y"
:
"REQUIRED"
}
],
"pt_api"
:
[
".add"
,
{
"input"
:
"REQUIRED"
,
"value"
:
"REQUIRED"
,
"alpha"
:
1
}
],
"ms2pt_mapping"
:
{
"input_x"
:
"input"
},
"gen_explicit_map"
:
"torch_dot_add_gen_explicit_map"
},
"torch.eye"
:
{
"ms_api"
:
[
"P.Eye"
,
{
"n"
:
"REQUIRED"
,
"m"
:
"REQUIRED"
,
"t"
:
"REQUIRED"
}
],
"pt_api"
:
[
".eye"
,
{
"n"
:
"REQUIRED"
,
"m"
:
"REQUIRED"
}
],
"ms2pt_mapping"
:
{
"n"
:
"n"
},
"gen_explicit_map"
:
"torch_dot_eye_gen_explicit_map"
}
}
}
}
\ No newline at end of file
tests/ut/mindconverter/test_converter.py
浏览文件 @
e569b96b
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
# ============================================================================
# ============================================================================
"""Test Converter"""
"""Test Converter"""
from
mindinsight.mindconverter.converter
import
Converter
from
mindinsight.mindconverter.converter
import
Converter
from
mindinsight.mindconverter.config
import
NN_MAPPING
class
TestConverter
:
class
TestConverter
:
...
@@ -82,3 +83,319 @@ class TestConverter:
...
@@ -82,3 +83,319 @@ class TestConverter:
result
=
self
.
converter_ins
.
find_right_parentheses
(
code
,
left_index
)
result
=
self
.
converter_ins
.
find_right_parentheses
(
code
,
left_index
)
assert_index
=
len
(
code
)
-
1
assert_index
=
len
(
code
)
-
1
assert
result
==
assert_index
assert
result
==
assert_index
# test convert_api with nn ops
def
test_convert_api_nn_layernorm
(
self
):
"""Test convert_api function work ok when convert api nn.LayerNorm"""
code
=
"""
def __init__(self, num_classes=1000):
self.features = nn.SequentialCell([
nn.LayerNorm((5, 10, 10), elementwise_affine=False),
nn.ReLU(inplace=False)
])
"""
api_name
=
'nn.LayerNorm'
start
=
code
.
find
(
api_name
)
layer_norm_info
=
NN_MAPPING
.
get
(
api_name
)
expected_ms_api_name
=
'nn.LayerNorm'
epsilon
=
layer_norm_info
.
pt_api
.
params
.
get
(
'eps'
)
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_name
)
assert
replaced_code
==
code
.
replace
(
'nn.LayerNorm((5, 10, 10), elementwise_affine=False)'
,
'{}(normalized_shape=(5, 10, 10), epsilon={})'
.
format
(
expected_ms_api_name
,
epsilon
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_nn_leaky_relu
(
self
):
"""Test convert_api function work ok when convert api nn.LeakyReLU"""
code
=
"""
def __init__(self, num_classes=1000):
self.features = nn.SequentialCell([
nn.LayerNorm((5, 10, 10), elementwise_affine=False),
nn.LeakyReLU(0.3)])
"""
api_name
=
'nn.LeakyReLU'
start
=
code
.
find
(
api_name
)
expected_ms_api_name
=
'nn.LeakyReLU'
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_name
)
assert
replaced_code
==
code
.
replace
(
'nn.LeakyReLU(0.3)'
,
'{}(alpha=0.3)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_nn_prelu
(
self
):
"""Test convert_api function work ok when convert api nn.PReLU"""
code
=
"""
input = torch.randn(2, 3, 5)
nn.PReLU()(input)
"""
api_name
=
'nn.PReLU'
start
=
code
.
find
(
api_name
)
expected_ms_api_name
=
'nn.PReLU'
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_name
)
assert
replaced_code
==
code
.
replace
(
'nn.PReLU()(input)'
,
'{}()(input)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_nn_softmax
(
self
):
"""Test convert_api function work ok when convert api nn.Softmax"""
code
=
"""
nn.Softmax(dim=1)(input)
"""
api_name
=
'nn.Softmax'
expected_ms_api_name
=
'nn.Softmax'
start
=
code
.
find
(
api_name
)
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_name
)
assert
replaced_code
==
code
.
replace
(
'nn.Softmax(dim=1)(input)'
,
'{}(axis=1)(input)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
# test convert_api with torch dot ops
def
test_convert_api_torch_dot_abs
(
self
):
"""Test convert_api function work ok when convert api torch.abs"""
code
=
"""
torch.abs(input)
"""
api_name
=
'torch.abs'
start
=
code
.
find
(
api_name
)
expected_ms_api_name
=
'P.Abs'
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_name
)
assert
replaced_code
==
code
.
replace
(
'torch.abs(input)'
,
'{}()(input)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_torch_dot_acos
(
self
):
"""Test convert_api function work ok when convert api torch.acos"""
code
=
"""
torch.acos(input)
"""
api_name
=
'torch.acos'
start
=
code
.
find
(
api_name
)
expected_ms_api_name
=
'P.ACos'
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_name
)
assert
replaced_code
==
code
.
replace
(
'torch.acos(input)'
,
'{}()(input)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_torch_dot_cos
(
self
):
"""Test convert_api function work ok when convert api torch.cos"""
code
=
"""
torch.cos(input)
"""
api_name
=
'torch.cos'
expected_ms_api_name
=
'P.Cos'
start
=
code
.
find
(
api_name
)
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_name
)
assert
replaced_code
==
code
.
replace
(
'torch.cos(input)'
,
'{}()(input)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_torch_dot_exp
(
self
):
"""Test convert_api function work ok when convert api torch.exp"""
code
=
"""
torch.exp(input)
"""
api_name
=
'torch.exp'
expected_ms_api_name
=
'P.Exp'
start
=
code
.
find
(
api_name
)
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_name
)
assert
replaced_code
==
code
.
replace
(
'torch.exp(input)'
,
'{}()(input)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_torch_dot_log
(
self
):
"""Test convert_api function work ok when convert api torch.log"""
code
=
"""
torch.log(input)
"""
api_name
=
'torch.log'
expected_ms_api_name
=
'P.Log'
start
=
code
.
find
(
api_name
)
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_name
)
assert
replaced_code
==
code
.
replace
(
'torch.log(input)'
,
'{}()(input)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_torch_dot_pow
(
self
):
"""Test convert_api function work ok when convert api torch.pow"""
code
=
"""
torch.pow(a, exp)
"""
api_name
=
'torch.pow'
expected_ms_api_name
=
'P.Pow'
start
=
code
.
find
(
api_name
)
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_name
)
assert
replaced_code
==
code
.
replace
(
'torch.pow(a, exp)'
,
'{}()(a, exp)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_torch_dot_div
(
self
):
"""Test convert_api function work ok when convert api torch.div"""
code
=
"""
input = torch.randn(5)
other = torch.randn(5)
torch.div(input, other)
"""
api_name
=
'torch.div'
expected_ms_api_name
=
'P.Div'
start
=
code
.
find
(
api_name
)
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_name
)
assert
replaced_code
==
code
.
replace
(
'torch.div(input, other)'
,
'{}()(input, other)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_torch_dot_sin
(
self
):
"""Test convert_api function work ok when convert api torch.sin"""
code
=
"""
torch.sin(input)
"""
api_name
=
'torch.sin'
expected_ms_api_name
=
'P.Sin'
start
=
code
.
find
(
api_name
)
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_name
)
assert
replaced_code
==
code
.
replace
(
'torch.sin(input)'
,
'{}()(input)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_torch_dot_sqrt
(
self
):
"""Test convert_api function work ok when convert api torch.sqrt"""
code
=
"""
torch.sqrt(input)
"""
api_name
=
'torch.sqrt'
expected_ms_api_name
=
'P.Sqrt'
start
=
code
.
find
(
api_name
)
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_name
)
assert
replaced_code
==
code
.
replace
(
'torch.sqrt(input)'
,
'{}()(input)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_torch_dot_eye_with_n
(
self
):
"""Test convert_api function work ok when convert api torch.eye"""
code
=
"""
torch.eye(3)
"""
api_name
=
'torch.eye'
expected_ms_api_name
=
'P.Eye'
start
=
code
.
find
(
api_name
)
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_name
)
assert
replaced_code
==
code
.
replace
(
'torch.eye(3)'
,
'{}()(3, 3, mindspore.int32)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_torch_dot_eye_with_m
(
self
):
"""Test convert_api function work ok when convert api torch.eye"""
code
=
"""
torch.eye(3, 4)
"""
api_name
=
'torch.eye'
expected_ms_api_name
=
'P.Eye'
start
=
code
.
find
(
api_name
)
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_name
)
assert
replaced_code
==
code
.
replace
(
'torch.eye(3, 4)'
,
'{}()(3, 4, mindspore.int32)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_torch_dot_add_with_alpha_default
(
self
):
"""Test convert_api function work ok when convert api torch.add"""
code
=
"""
torch.add(input, value)
"""
api_name
=
'torch.add'
expected_ms_api_name
=
'P.TensorAdd'
start
=
code
.
find
(
api_name
)
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_name
)
assert
replaced_code
==
code
.
replace
(
'torch.add(input, value)'
,
'{}()(input, value)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_torch_dot_add_with_alpha_not_default
(
self
):
"""Test convert_api function work ok when convert api torch.add"""
code
=
"""
torch.add(input, value, 3)
"""
api_name
=
'torch.add'
expected_ms_api_name
=
'P.TensorAdd'
start
=
code
.
find
(
api_name
)
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_name
)
assert
replaced_code
==
code
.
replace
(
'torch.add(input, value, 3)'
,
'{}()(input, value*3)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
# test convert_api with F ops
def
test_convert_api_f_normalize
(
self
):
"""Test convert_api function work ok when convert api F.normalize"""
code
=
"""
input = torch.randn(2, 3, 5)
F.normalize(input)
"""
api_name
=
'F.normalize'
start
=
code
.
find
(
api_name
)
expected_ms_api_name
=
'P.L2Normalize'
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_name
)
assert
replaced_code
==
code
.
replace
(
'F.normalize(input)'
,
'{}(1, 1e-12)(input)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_f_sigmoid
(
self
):
"""Test convert_api function work ok when convert api F.sigmoid"""
code
=
"""
input = torch.randn(2, 3, 5)
F.sigmoid(input)
"""
api_name
=
'F.sigmoid'
start
=
code
.
find
(
api_name
)
expected_ms_api_name
=
'P.Sigmoid'
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_name
)
assert
replaced_code
==
code
.
replace
(
'F.sigmoid(input)'
,
'{}()(input)'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
# test convert_api with tensor dot ops
def
test_convert_api_tensor_dot_repeat
(
self
):
"""Test convert_api function work ok when convert api .repeat"""
code
=
"""
x.repeat(4, 2)
"""
api_name
=
'.repeat'
start
=
code
.
find
(
api_name
)
expected_ms_api_name
=
'P.Tile'
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_name
)
assert
replaced_code
==
code
.
replace
(
'x.repeat(4, 2)'
,
'{}()(x, {})'
.
format
(
expected_ms_api_name
,
'(4, 2,)'
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
def
test_convert_api_tensor_dot_permute
(
self
):
"""Test convert_api function work ok when convert api .permute"""
code
=
"""
x.permute(2, 0, 1)
"""
api_name
=
'.permute'
start
=
code
.
find
(
api_name
)
expected_ms_api_name
=
'P.Transpose'
replaced_code
,
new_start
=
self
.
converter_ins
.
convert_api
(
code
,
start
,
api_name
)
assert
replaced_code
==
code
.
replace
(
'x.permute(2, 0, 1)'
,
'{}()(x, (2, 0, 1,))'
.
format
(
expected_ms_api_name
))
assert
new_start
==
start
+
len
(
expected_ms_api_name
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录