Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
慢慢CG
Mace
提交
97624cf2
Mace
项目概览
慢慢CG
/
Mace
与 Fork 源项目一致
Fork自
Xiaomi / Mace
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
97624cf2
编写于
10月 27, 2020
作者:
L
like15
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix: Extract Tensors from `params` instead of `state_dict` to be compatible with PyTorch 1.4
上级
9e6ee11c
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
31 addition
and
50 deletion
+31
-50
tools/python/transform/pytorch_converter.py
tools/python/transform/pytorch_converter.py
+31
-50
未找到文件。
tools/python/transform/pytorch_converter.py
浏览文件 @
97624cf2
...
...
@@ -26,15 +26,15 @@ from torch.onnx.utils import _node_getitem
from
py_proto
import
mace_pb2
from
transform
import
base_converter
from
transform.transformer
import
Transformer
from
transform.base_converter
import
PoolingType
from
transform.base_converter
import
ActivationType
from
transform.base_converter
import
EltwiseType
from
transform.base_converter
import
FrameworkType
from
transform.base_converter
import
MaceOp
from
transform.base_converter
import
MaceKeyword
from
transform.base_converter
import
ConverterUtil
from
transform.base_converter
import
RoundMode
from
transform.base_converter
import
DataFormat
from
transform.base_converter
import
MaceKeyword
from
transform.base_converter
import
MaceOp
from
transform.base_converter
import
PoolingType
from
transform.base_converter
import
RoundMode
from
utils.util
import
mace_check
...
...
@@ -48,7 +48,11 @@ def _model_to_graph(model, args):
in_vars
,
in_desc
=
torch
.
jit
.
_flatten
(
tuple
(
args
)
+
tuple
(
params
))
graph
=
_propagate_and_assign_input_shapes
(
method_graph
,
tuple
(
in_vars
),
False
,
propagate
)
return
graph
input_and_param_names
=
[
val
.
debugName
()
for
val
in
graph
.
inputs
()]
param_names
=
input_and_param_names
[
-
len
(
params
):]
params
=
[
elem
.
detach
()
for
elem
in
params
]
params_dict
=
dict
(
zip
(
param_names
,
params
))
return
graph
,
params_dict
class
ValueType
(
object
):
...
...
@@ -172,26 +176,8 @@ class PytorchConverter(base_converter.ConverterInterface):
else
:
dummy_input
=
dummy_input
+
(
torch
.
randn
(
in_node
.
shape
),)
graph
=
_model_to_graph
(
self
.
_loaded_model
,
dummy_input
)
state_dict
=
self
.
_loaded_model
.
state_dict
()
'''
num_batches_tracked in state_dict for BN layer is not used by any node,
delete them to avoid mistake name.
Maybe there are more unused keys in the future.
'''
unneeded_keys
=
[]
for
key
in
state_dict
.
keys
():
if
re
.
match
(
r
'.*\.num_batches_tracked'
,
key
):
unneeded_keys
.
append
(
key
)
for
key
in
unneeded_keys
:
del
state_dict
[
key
]
graph_inputs
=
list
(
graph
.
inputs
())
user_input_num
=
len
(
graph_inputs
)
-
len
(
state_dict
)
param_names
=
list
(
state_dict
.
keys
())
for
i
,
inp
in
enumerate
(
graph
.
inputs
()):
if
i
>=
user_input_num
:
inp
.
setDebugName
(
param_names
[
i
-
user_input_num
])
return
graph
graph
,
params_dict
=
_model_to_graph
(
self
.
_loaded_model
,
dummy_input
)
return
graph
,
params_dict
def
init_output_shape_cache
(
self
):
self
.
_output_shape_cache
=
{}
...
...
@@ -240,16 +226,13 @@ class PytorchConverter(base_converter.ConverterInterface):
}
self
.
_loaded_model
=
torch
.
jit
.
load
(
src_model_file
)
self
.
_loaded_model
.
eval
()
self
.
_graph
=
self
.
model_to_graph
()
self
.
_graph
,
self
.
_params_dict
=
self
.
model_to_graph
()
self
.
_output_node_name
=
list
(
self
.
_graph
.
outputs
())[
0
].
debugName
()
self
.
_output_value_type
=
list
(
self
.
_graph
.
outputs
())[
0
].
type
()
if
not
isinstance
(
self
.
_output_value_type
,
(
ValueType
.
TensorType
,
ValueType
.
ListType
,
ValueType
.
TupleType
)):
print
(
'return type {} not supported'
.
format
(
mace_check
(
isinstance
(
self
.
_output_value_type
,
(
ValueType
.
TensorType
,
ValueType
.
ListType
,
ValueType
.
TupleType
)),
'return type {} not supported'
.
format
(
self
.
_output_value_type
))
sys
.
exit
(
1
)
self
.
_node_map
=
{}
self
.
init_output_shape_cache
()
...
...
@@ -405,7 +388,7 @@ class PytorchConverter(base_converter.ConverterInterface):
# OIHW
key
=
inputs_vals
[
1
].
debugName
()
filter_shape
=
self
.
_
loaded_model
.
state_dict
()
[
key
].
shape
filter_shape
=
self
.
_
params_dict
[
key
].
shape
filter_shape
=
[
int
(
elem
)
for
elem
in
filter_shape
]
# Size -> list
mace_check
(
len
(
filter_shape
)
==
4
,
'MACE only supports 2D Conv, current Conv is {}D'
.
format
(
...
...
@@ -446,7 +429,7 @@ class PytorchConverter(base_converter.ConverterInterface):
dilation_arg
.
ints
.
extend
(
mace_dilations
)
filter_tensor_name
=
inputs_vals
[
ConvParamIdx
.
weight_idx
].
debugName
()
filter_data
=
self
.
_
loaded_model
.
state_dict
()
[
filter_tensor_name
]
filter_data
=
self
.
_
params_dict
[
filter_tensor_name
]
if
is_depthwise
:
# C1HW => 1CHW
filter_data
=
filter_data
.
permute
((
1
,
0
,
2
,
3
))
...
...
@@ -458,7 +441,7 @@ class PytorchConverter(base_converter.ConverterInterface):
has_bias
=
(
not
isinstance
(
bias_val
.
type
(),
ValueType
.
NoneType
))
if
has_bias
:
bias_tensor_name
=
inputs_vals
[
ConvParamIdx
.
bias_idx
].
debugName
()
bias_data
=
self
.
_
loaded_model
.
state_dict
()
[
bias_tensor_name
]
bias_data
=
self
.
_
params_dict
[
bias_tensor_name
]
bias_data
=
bias_data
.
numpy
()
self
.
add_tensor_and_shape
(
bias_tensor_name
,
bias_data
.
shape
,
mace_pb2
.
DT_FLOAT
,
bias_data
)
...
...
@@ -476,7 +459,7 @@ class PytorchConverter(base_converter.ConverterInterface):
mace_check
(
is_training
==
0
,
"Only support batch normalization with is_training = 0,"
" but got {}"
.
format
(
is_training
))
state_dict
=
self
.
_
loaded_model
.
state_dict
()
state_dict
=
self
.
_
params_dict
gamma_key
=
inputs_vals
[
BNParamIdx
.
weight_idx
].
debugName
()
gamma_value
=
state_dict
[
gamma_key
].
numpy
().
astype
(
np
.
float32
)
beta_key
=
inputs_vals
[
BNParamIdx
.
bias_idx
].
debugName
()
...
...
@@ -515,15 +498,13 @@ class PytorchConverter(base_converter.ConverterInterface):
type_arg
=
op
.
arg
.
add
()
type_arg
.
name
=
MaceKeyword
.
mace_activation_type_str
if
(
abs
(
max_val
-
6.
)
<
1e-8
):
mace_check
(
abs
(
max_val
-
6.
)
<
1e-8
,
'only support converting hardtanh_ to ReLU6 yet'
)
type_arg
.
s
=
six
.
b
(
self
.
activation_type
[
'ReLU6'
].
name
)
limit_arg
=
op
.
arg
.
add
()
limit_arg
.
name
=
MaceKeyword
.
mace_activation_max_limit_str
limit_arg
.
f
=
6.0
else
:
print
(
'only support converting hardtanh_ to ReLU6 yet'
)
sys
.
exit
(
1
)
self
.
infer_shape_general
(
op
)
def
convert_add
(
self
,
node
,
inputs_vals
,
outputs_vals
):
...
...
@@ -632,14 +613,14 @@ class PytorchConverter(base_converter.ConverterInterface):
def
get_weight_from_node
(
self
,
node
):
input_list
=
list
(
node
.
inputs
())
key
=
input_list
[
0
].
debugName
()
return
self
.
_
loaded_model
.
state_dict
()
[
key
]
return
self
.
_
params_dict
[
key
]
def
is_trans_fc_w
(
self
,
node
):
in_vals
=
list
(
node
.
inputs
())
mace_check
(
len
(
in_vals
)
==
1
,
't() must have 1 input'
)
in_name
=
in_vals
[
0
].
debugName
()
if
in_name
in
self
.
_
loaded_model
.
state_dict
()
and
\
len
(
self
.
_
loaded_model
.
state_dict
()
[
in_name
].
shape
)
==
2
:
if
in_name
in
self
.
_
params_dict
and
\
len
(
self
.
_
params_dict
[
in_name
].
shape
)
==
2
:
return
True
return
False
...
...
@@ -662,7 +643,7 @@ class PytorchConverter(base_converter.ConverterInterface):
alpha_type
=
inputs_vals
[
AddmmParamIdx
.
alpha_idx
].
type
()
is_alpha_fc
=
isinstance
(
alpha_type
,
ValueType
.
IntType
)
and
alpha
==
1
is_bias_w
=
inputs_vals
[
AddmmParamIdx
.
bias_idx
].
debugName
()
in
\
self
.
_
loaded_model
.
state_dict
()
self
.
_
params_dict
beta
=
inputs_vals
[
AddmmParamIdx
.
beta_idx
].
node
()[
'value'
]
beta_type
=
inputs_vals
[
AddmmParamIdx
.
beta_idx
].
type
()
is_beta_fc
=
isinstance
(
beta_type
,
ValueType
.
IntType
)
and
beta
==
1
...
...
@@ -703,7 +684,7 @@ class PytorchConverter(base_converter.ConverterInterface):
opb
.
type
=
MaceOp
.
BiasAdd
.
name
bias_tensor_name
=
opb
.
name
+
'_bias'
key
=
inputs_vals
[
AddmmParamIdx
.
bias_idx
].
debugName
()
bias_data
=
self
.
_
loaded_model
.
state_dict
()
[
key
]
bias_data
=
self
.
_
params_dict
[
key
]
bias_data
=
bias_data
.
numpy
()
self
.
add_tensor_and_shape
(
bias_tensor_name
,
bias_data
.
reshape
(
-
1
).
shape
,
mace_pb2
.
DT_FLOAT
,
bias_data
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录