Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
项目经理老王
Mace
提交
9efe5dc5
Mace
项目概览
项目经理老王
/
Mace
与 Fork 源项目一致
Fork自
Xiaomi / Mace
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
9efe5dc5
编写于
9月 30, 2020
作者:
L
like15
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat: Support Densenet and MobilenetV2 conversion from PyTorch
上级
e1f4fd86
变更
12
展开全部
隐藏空白更改
内联
并排
Showing
12 changed file
with
1160 addition
and
14 deletion
+1160
-14
mace/core/types.h
mace/core/types.h
+2
-0
tools/converter.py
tools/converter.py
+8
-0
tools/device.py
tools/device.py
+3
-0
tools/python/convert.py
tools/python/convert.py
+4
-0
tools/python/transform/base_converter.py
tools/python/transform/base_converter.py
+1
-0
tools/python/transform/pytorch_converter.py
tools/python/transform/pytorch_converter.py
+1000
-0
tools/python/transform/transformer.py
tools/python/transform/transformer.py
+38
-10
tools/python/utils/config_parser.py
tools/python/utils/config_parser.py
+1
-0
tools/python/utils/device.py
tools/python/utils/device.py
+2
-2
tools/python/validate.py
tools/python/validate.py
+49
-0
tools/sh_commands.py
tools/sh_commands.py
+3
-2
tools/validate.py
tools/validate.py
+49
-0
未找到文件。
mace/core/types.h
浏览文件 @
9efe5dc5
...
@@ -69,6 +69,8 @@ enum FrameworkType {
...
@@ -69,6 +69,8 @@ enum FrameworkType {
TENSORFLOW
=
0
,
TENSORFLOW
=
0
,
CAFFE
=
1
,
CAFFE
=
1
,
ONNX
=
2
,
ONNX
=
2
,
MEGENGINE
=
3
,
PYTORCH
=
4
};
};
template
<
typename
T
>
template
<
typename
T
>
...
...
tools/converter.py
浏览文件 @
9efe5dc5
...
@@ -61,6 +61,7 @@ PlatformTypeStrs = [
...
@@ -61,6 +61,7 @@ PlatformTypeStrs = [
"caffe"
,
"caffe"
,
"onnx"
,
"onnx"
,
"megengine"
,
"megengine"
,
"pytorch"
,
]
]
PlatformType
=
Enum
(
'PlatformType'
,
[(
ele
,
ele
)
for
ele
in
PlatformTypeStrs
],
PlatformType
=
Enum
(
'PlatformType'
,
[(
ele
,
ele
)
for
ele
in
PlatformTypeStrs
],
type
=
str
)
type
=
str
)
...
@@ -520,6 +521,13 @@ def format_model_config(flags):
...
@@ -520,6 +521,13 @@ def format_model_config(flags):
if
not
isinstance
(
value
,
list
):
if
not
isinstance
(
value
,
list
):
subgraph
[
key
]
=
[
value
]
subgraph
[
key
]
=
[
value
]
subgraph
[
key
]
=
[
str
(
v
)
for
v
in
subgraph
[
key
]]
subgraph
[
key
]
=
[
str
(
v
)
for
v
in
subgraph
[
key
]]
# --inputs_shapes will be passed to ELF file `mace_run_static', if input_shapes
# contains spaces, such as: '1, 3, 224, 224', because mace_run.cc use gflags to
# parse command line arguments, --input_shapes 1, 3, 224, 224 will be passed as
# `--input_shapes 1,'. So we strip out spaces here.
if
key
in
[
YAMLKeyword
.
input_shapes
,
YAMLKeyword
.
output_shapes
]:
subgraph
[
key
]
=
[
e
.
replace
(
' '
,
''
)
for
e
in
subgraph
[
key
]]
input_size
=
len
(
subgraph
[
YAMLKeyword
.
input_tensors
])
input_size
=
len
(
subgraph
[
YAMLKeyword
.
input_tensors
])
output_size
=
len
(
subgraph
[
YAMLKeyword
.
output_tensors
])
output_size
=
len
(
subgraph
[
YAMLKeyword
.
output_tensors
])
...
...
tools/device.py
浏览文件 @
9efe5dc5
...
@@ -632,6 +632,9 @@ class DeviceWrapper:
...
@@ -632,6 +632,9 @@ class DeviceWrapper:
'Run model {} on {}'
.
format
(
model_name
,
self
.
device_name
)))
'Run model {} on {}'
.
format
(
model_name
,
self
.
device_name
)))
model_config
=
configs
[
YAMLKeyword
.
models
][
model_name
]
model_config
=
configs
[
YAMLKeyword
.
models
][
model_name
]
if
model_config
[
YAMLKeyword
.
platform
]
==
'pytorch'
:
mace_check
(
flags
.
layers
==
"-1"
,
"Device"
,
'extracting intermediate layer output is not supported in pytorch JIT yet'
)
# noqa
model_runtime
=
model_config
[
YAMLKeyword
.
runtime
]
model_runtime
=
model_config
[
YAMLKeyword
.
runtime
]
subgraphs
=
model_config
[
YAMLKeyword
.
subgraphs
]
subgraphs
=
model_config
[
YAMLKeyword
.
subgraphs
]
...
...
tools/python/convert.py
浏览文件 @
9efe5dc5
...
@@ -190,6 +190,10 @@ def convert_model(conf, quantize_stat):
...
@@ -190,6 +190,10 @@ def convert_model(conf, quantize_stat):
from
transform
import
megengine_converter
from
transform
import
megengine_converter
converter
=
megengine_converter
.
MegengineConverter
(
converter
=
megengine_converter
.
MegengineConverter
(
option
,
conf
[
"model_file_path"
])
option
,
conf
[
"model_file_path"
])
elif
platform
==
Platform
.
PYTORCH
:
from
transform
import
pytorch_converter
converter
=
pytorch_converter
.
PytorchConverter
(
option
,
conf
[
"model_file_path"
])
else
:
else
:
mace_check
(
False
,
"Mace do not support platorm %s yet."
%
platform
)
mace_check
(
False
,
"Mace do not support platorm %s yet."
%
platform
)
...
...
tools/python/transform/base_converter.py
浏览文件 @
9efe5dc5
...
@@ -88,6 +88,7 @@ class FrameworkType(Enum):
...
@@ -88,6 +88,7 @@ class FrameworkType(Enum):
CAFFE
=
1
CAFFE
=
1
ONNX
=
2
ONNX
=
2
MEGENGINE
=
3
MEGENGINE
=
3
PYTORCH
=
4
MaceSupportedOps
=
[
MaceSupportedOps
=
[
...
...
tools/python/transform/pytorch_converter.py
0 → 100644
浏览文件 @
9efe5dc5
此差异已折叠。
点击以展开。
tools/python/transform/transformer.py
浏览文件 @
9efe5dc5
...
@@ -345,6 +345,7 @@ class Transformer(base_converter.ConverterInterface):
...
@@ -345,6 +345,7 @@ class Transformer(base_converter.ConverterInterface):
input_info
.
dims
.
extend
(
input_node
.
shape
)
input_info
.
dims
.
extend
(
input_node
.
shape
)
input_info
.
data_type
=
input_node
.
data_type
input_info
.
data_type
=
input_node
.
data_type
# tools/python/convert.py sets option.check_nodes
output_nodes
=
self
.
_option
.
check_nodes
.
values
()
output_nodes
=
self
.
_option
.
check_nodes
.
values
()
for
output_node
in
output_nodes
:
for
output_node
in
output_nodes
:
output_info
=
net
.
output_info
.
add
()
output_info
=
net
.
output_info
.
add
()
...
@@ -1312,12 +1313,18 @@ class Transformer(base_converter.ConverterInterface):
...
@@ -1312,12 +1313,18 @@ class Transformer(base_converter.ConverterInterface):
for
op
in
net
.
op
:
for
op
in
net
.
op
:
# transform `input(4D) -> reshape(2D) -> matmul` to `fc(2D)`
# transform `input(4D) -> reshape(2D) -> matmul` to `fc(2D)`
# fc output is 2D in transformer, using as 4D in op kernel
# fc output is 2D in transformer, using as 4D in op kernel
# work for TensorFlow
# work for TensorFlow/PyTorch/ONNX
framework
=
ConverterUtil
.
get_arg
(
op
,
MaceKeyword
.
mace_framework_type_str
).
i
is_torch
=
framework
==
FrameworkType
.
PYTORCH
.
value
is_tf
=
framework
==
FrameworkType
.
TENSORFLOW
.
value
is_onnx
=
framework
==
FrameworkType
.
ONNX
.
value
if
op
.
type
==
MaceOp
.
Reshape
.
name
and
\
if
op
.
type
==
MaceOp
.
Reshape
.
name
and
\
len
(
op
.
input
)
==
2
and
\
len
(
op
.
input
)
==
2
and
\
op
.
input
[
1
]
in
self
.
_consts
and
\
op
.
input
[
1
]
in
self
.
_consts
and
\
len
(
op
.
output_shape
[
0
].
dims
)
==
2
and
\
len
(
op
.
output_shape
[
0
].
dims
)
==
2
and
\
filter_format
==
DataFormat
.
HWIO
and
\
(
is_tf
or
is_torch
or
is_onnx
)
and
\
op
.
input
[
0
]
in
self
.
_producer
:
op
.
input
[
0
]
in
self
.
_producer
:
input_op
=
self
.
_producer
[
op
.
input
[
0
]]
input_op
=
self
.
_producer
[
op
.
input
[
0
]]
input_shape
=
input_op
.
output_shape
[
0
].
dims
input_shape
=
input_op
.
output_shape
[
0
].
dims
...
@@ -1332,8 +1339,13 @@ class Transformer(base_converter.ConverterInterface):
...
@@ -1332,8 +1339,13 @@ class Transformer(base_converter.ConverterInterface):
is_fc
=
False
is_fc
=
False
else
:
else
:
weight
=
self
.
_consts
[
matmul_op
.
input
[
1
]]
weight
=
self
.
_consts
[
matmul_op
.
input
[
1
]]
if
len
(
weight
.
dims
)
!=
2
or
\
od
=
op
.
output_shape
[
0
].
dims
weight
.
dims
[
0
]
!=
op
.
output_shape
[
0
].
dims
[
1
]:
wd
=
weight
.
dims
if
len
(
wd
)
!=
2
:
is_fc
=
False
# tf fc weight: IO; onnx/pytorch fc weight: OI
if
(
is_tf
and
wd
[
0
]
!=
od
[
1
])
or
\
((
is_torch
or
is_onnx
)
and
wd
[
1
]
!=
od
[
1
]):
is_fc
=
False
is_fc
=
False
if
is_fc
:
if
is_fc
:
print
(
'convert reshape and matmul to fc'
)
print
(
'convert reshape and matmul to fc'
)
...
@@ -1344,24 +1356,40 @@ class Transformer(base_converter.ConverterInterface):
...
@@ -1344,24 +1356,40 @@ class Transformer(base_converter.ConverterInterface):
matmul_op
.
type
=
MaceOp
.
FullyConnected
.
name
matmul_op
.
type
=
MaceOp
.
FullyConnected
.
name
weight_data
=
np
.
array
(
weight
.
float_data
).
reshape
(
weight_data
=
np
.
array
(
weight
.
float_data
).
reshape
(
weight
.
dims
)
weight
.
dims
)
weight
.
dims
[:]
=
input_shape
[
1
:]
+
\
if
is_tf
:
[
weight_data
.
shape
[
1
]]
weight
.
dims
[:]
=
input_shape
[
1
:]
+
\
[
weight_data
.
shape
[
1
]]
if
is_torch
or
is_onnx
:
in_data_format
=
ConverterUtil
.
data_format
(
input_op
)
# OI+NCHW[2:]=OIHW
if
in_data_format
==
DataFormat
.
NCHW
:
weight
.
dims
.
extend
(
input_shape
[
2
:])
# OI+NHWC[1:3]=OIHW
else
:
weight
.
dims
.
extend
(
input_shape
[
1
:
3
])
return
True
return
True
# transform `fc1(2D) -> matmul` to `fc1(2D) -> fc1(2D)`
# transform `fc1(2D) -> matmul` to `fc1(2D) -> fc1(2D)`
if
op
.
type
==
MaceOp
.
MatMul
.
name
and
\
if
op
.
type
==
MaceOp
.
MatMul
.
name
and
\
filter_format
==
DataFormat
.
HWIO
and
\
(
is_tf
or
is_torch
or
is_onnx
)
and
\
op
.
input
[
1
]
in
self
.
_consts
:
op
.
input
[
1
]
in
self
.
_consts
:
producer
=
self
.
_producer
[
op
.
input
[
0
]]
producer
=
self
.
_producer
[
op
.
input
[
0
]]
weight
=
self
.
_consts
[
op
.
input
[
1
]]
weight
=
self
.
_consts
[
op
.
input
[
1
]]
if
len
(
weight
.
dims
)
==
2
and
self
.
is_after_fc
(
op
)
and
\
if
len
(
weight
.
dims
)
==
2
and
self
.
is_after_fc
(
op
)
and
\
len
(
producer
.
output_shape
[
0
].
dims
)
==
2
and
\
len
(
producer
.
output_shape
[
0
].
dims
)
==
2
and
\
weight
.
dims
[
0
]
==
producer
.
output_shape
[
0
].
dims
[
1
]:
((
is_tf
and
weight
.
dims
[
0
]
==
producer
.
output_shape
[
0
].
dims
[
1
])
or
# noqa
(
is_torch
and
weight
.
dims
[
1
]
==
producer
.
output_shape
[
0
].
dims
[
1
])
or
# noqa
(
is_onnx
and
weight
.
dims
[
1
]
==
producer
.
output_shape
[
0
].
dims
[
1
])):
# noqa
six
.
print_
(
'convert matmul to fc'
)
six
.
print_
(
'convert matmul to fc'
)
op
.
type
=
MaceOp
.
FullyConnected
.
name
op
.
type
=
MaceOp
.
FullyConnected
.
name
weight_data
=
np
.
array
(
weight
.
float_data
).
reshape
(
weight_data
=
np
.
array
(
weight
.
float_data
).
reshape
(
weight
.
dims
)
weight
.
dims
)
weight
.
dims
[:]
=
[
1
,
1
]
+
list
(
weight_data
.
shape
)
# only 1 of the 2 branches can be executed
if
is_tf
:
weight
.
dims
[:]
=
[
1
,
1
]
+
list
(
weight_data
.
shape
)
if
is_torch
or
is_onnx
:
weight
.
dims
.
extend
([
1
,
1
])
return
True
return
True
if
self
.
_option
.
device
==
DeviceType
.
APU
.
value
:
if
self
.
_option
.
device
==
DeviceType
.
APU
.
value
:
...
@@ -2257,7 +2285,7 @@ class Transformer(base_converter.ConverterInterface):
...
@@ -2257,7 +2285,7 @@ class Transformer(base_converter.ConverterInterface):
dim_arg
=
ConverterUtil
.
get_arg
(
op
,
MaceKeyword
.
mace_dim_str
)
dim_arg
=
ConverterUtil
.
get_arg
(
op
,
MaceKeyword
.
mace_dim_str
)
shape_tensor
=
None
shape_tensor
=
None
if
len
(
op
.
input
)
==
1
:
if
len
(
op
.
input
)
==
1
:
print
(
"Transform Caffe Reshape"
)
print
(
"Transform Caffe
or PyTorch
Reshape"
)
dims
=
[]
dims
=
[]
axis_arg
=
ConverterUtil
.
get_arg
(
op
,
MaceKeyword
.
mace_axis_str
)
axis_arg
=
ConverterUtil
.
get_arg
(
op
,
MaceKeyword
.
mace_axis_str
)
# transform caffe reshape op
# transform caffe reshape op
...
...
tools/python/utils/config_parser.py
浏览文件 @
9efe5dc5
...
@@ -151,6 +151,7 @@ class Platform(Enum):
...
@@ -151,6 +151,7 @@ class Platform(Enum):
CAFFE
=
1
CAFFE
=
1
ONNX
=
2
ONNX
=
2
MEGENGINE
=
3
MEGENGINE
=
3
PYTORCH
=
4
def
parse_platform
(
str
):
def
parse_platform
(
str
):
...
...
tools/python/utils/device.py
浏览文件 @
9efe5dc5
...
@@ -51,8 +51,8 @@ def execute(cmd, verbose=True):
...
@@ -51,8 +51,8 @@ def execute(cmd, verbose=True):
print
(
line
)
print
(
line
)
buf
.
append
(
line
)
buf
.
append
(
line
)
for
l
in
p
.
stdout
:
for
l
i
in
p
.
stdout
:
line
=
l
.
strip
()
line
=
l
i
.
strip
()
if
verbose
:
if
verbose
:
print
(
line
)
print
(
line
)
buf
.
append
(
line
)
buf
.
append
(
line
)
...
...
tools/python/validate.py
浏览文件 @
9efe5dc5
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
# limitations under the License.
# limitations under the License.
import
os
import
os
import
sys
import
os.path
import
os.path
import
numpy
as
np
import
numpy
as
np
import
six
import
six
...
@@ -204,6 +205,48 @@ def validate_tf_model(model_file,
...
@@ -204,6 +205,48 @@ def validate_tf_model(model_file,
validation_threshold
,
log_file
)
validation_threshold
,
log_file
)
def
validate_pytorch_model
(
model_file
,
input_file
,
mace_out_file
,
input_names
,
input_shapes
,
input_data_formats
,
output_names
,
output_shapes
,
output_data_formats
,
validation_threshold
,
input_data_types
,
log_file
):
import
torch
loaded_model
=
torch
.
jit
.
load
(
model_file
)
pytorch_inputs
=
[]
for
i
in
range
(
len
(
input_names
)):
input_value
=
load_data
(
util
.
formatted_file_name
(
input_file
,
input_names
[
i
]),
input_data_types
[
i
])
input_value
=
input_value
.
reshape
(
input_shapes
[
i
])
if
input_data_formats
[
i
]
==
DataFormat
.
NHWC
and
\
len
(
input_shapes
[
i
])
==
4
:
input_value
=
input_value
.
transpose
((
0
,
3
,
1
,
2
))
input_value
=
torch
.
from_numpy
(
input_value
)
pytorch_inputs
.
append
(
input_value
)
with
torch
.
no_grad
():
pytorch_outputs
=
loaded_model
(
*
pytorch_inputs
)
if
isinstance
(
pytorch_outputs
,
torch
.
Tensor
):
pytorch_outputs
=
[
pytorch_outputs
]
else
:
if
not
isinstance
(
pytorch_outputs
,
(
list
,
tuple
)):
print
(
'return type {} unsupported'
.
format
(
type
(
pytorch_outputs
)))
sys
.
exit
(
1
)
for
i
in
range
(
len
(
output_names
)):
value
=
pytorch_outputs
[
i
].
numpy
()
output_file_name
=
util
.
formatted_file_name
(
mace_out_file
,
output_names
[
i
])
mace_out_value
=
load_data
(
output_file_name
)
# MACE: always returns tensor of dim 1
# pytorch: NCHW, conversion is needed
if
output_data_formats
[
i
]
==
DataFormat
.
NHWC
and
\
len
(
output_shapes
[
i
])
==
4
:
mace_out_value
=
mace_out_value
.
reshape
(
output_shapes
[
i
])
\
.
transpose
((
0
,
3
,
1
,
2
))
compare_output
(
output_names
[
i
],
mace_out_value
,
value
,
validation_threshold
,
log_file
)
def
validate_caffe_model
(
model_file
,
input_file
,
def
validate_caffe_model
(
model_file
,
input_file
,
mace_out_file
,
weight_file
,
mace_out_file
,
weight_file
,
input_names
,
input_shapes
,
input_data_formats
,
input_names
,
input_shapes
,
input_data_formats
,
...
@@ -387,6 +430,12 @@ def validate(platform, model_file, weight_file, input_file, mace_out_file,
...
@@ -387,6 +430,12 @@ def validate(platform, model_file, weight_file, input_file, mace_out_file,
output_node
,
output_shape
,
output_data_format
,
output_node
,
output_shape
,
output_data_format
,
validation_threshold
,
input_data_type
,
validation_threshold
,
input_data_type
,
log_file
)
log_file
)
elif
platform
==
Platform
.
PYTORCH
:
validate_pytorch_model
(
model_file
,
input_file
,
mace_out_file
,
input_node
,
input_shape
,
input_data_format
,
output_node
,
output_shape
,
output_data_format
,
validation_threshold
,
input_data_type
,
log_file
)
elif
platform
==
Platform
.
CAFFE
:
elif
platform
==
Platform
.
CAFFE
:
validate_caffe_model
(
model_file
,
validate_caffe_model
(
model_file
,
input_file
,
mace_out_file
,
weight_file
,
input_file
,
mace_out_file
,
weight_file
,
...
...
tools/sh_commands.py
浏览文件 @
9efe5dc5
...
@@ -53,7 +53,8 @@ def strip_invalid_utf8(str):
...
@@ -53,7 +53,8 @@ def strip_invalid_utf8(str):
def
split_stdout
(
stdout_str
):
def
split_stdout
(
stdout_str
):
stdout_str
=
strip_invalid_utf8
(
stdout_str
)
stdout_str
=
strip_invalid_utf8
(
stdout_str
)
# Filter out last empty line
# Filter out last empty line
return
[
l
.
strip
()
for
l
in
stdout_str
.
split
(
'
\n
'
)
if
len
(
l
.
strip
())
>
0
]
return
[
line
.
strip
()
for
line
in
stdout_str
.
split
(
'
\n
'
)
if
len
(
line
.
strip
())
>
0
]
def
make_output_processor
(
buff
):
def
make_output_processor
(
buff
):
...
@@ -659,7 +660,7 @@ def validate_model(abi,
...
@@ -659,7 +660,7 @@ def validate_model(abi,
sh
.
rm
(
"-rf"
,
"%s/%s"
%
(
model_output_dir
,
formatted_name
))
sh
.
rm
(
"-rf"
,
"%s/%s"
%
(
model_output_dir
,
formatted_name
))
device
.
pull_from_data_dir
(
formatted_name
,
model_output_dir
)
device
.
pull_from_data_dir
(
formatted_name
,
model_output_dir
)
if
platform
==
"tensorflow"
or
platform
==
"onnx"
:
if
platform
==
"tensorflow"
or
platform
==
"onnx"
or
platform
==
"pytorch"
:
validate
(
platform
,
model_file_path
,
""
,
validate
(
platform
,
model_file_path
,
""
,
"%s/%s"
%
(
model_output_dir
,
input_file_name
),
"%s/%s"
%
(
model_output_dir
,
input_file_name
),
"%s/%s"
%
(
model_output_dir
,
output_file_name
),
device_type
,
"%s/%s"
%
(
model_output_dir
,
output_file_name
),
device_type
,
...
...
tools/validate.py
浏览文件 @
9efe5dc5
...
@@ -216,6 +216,48 @@ def validate_tf_model(platform, device_type, model_file,
...
@@ -216,6 +216,48 @@ def validate_tf_model(platform, device_type, model_file,
validation_threshold
,
log_file
)
validation_threshold
,
log_file
)
def
validate_pytorch_model
(
platform
,
device_type
,
model_file
,
input_file
,
mace_out_file
,
input_names
,
input_shapes
,
input_data_formats
,
output_names
,
output_shapes
,
output_data_formats
,
validation_threshold
,
input_data_types
,
log_file
):
import
torch
loaded_model
=
torch
.
jit
.
load
(
model_file
)
pytorch_inputs
=
[]
for
i
in
range
(
len
(
input_names
)):
input_value
=
load_data
(
common
.
formatted_file_name
(
input_file
,
input_names
[
i
]),
input_data_types
[
i
])
input_value
=
input_value
.
reshape
(
input_shapes
[
i
])
if
input_data_formats
[
i
]
==
common
.
DataFormat
.
NHWC
and
\
len
(
input_shapes
[
i
])
==
4
:
input_value
=
input_value
.
transpose
((
0
,
3
,
1
,
2
))
input_value
=
torch
.
from_numpy
(
input_value
)
pytorch_inputs
.
append
(
input_value
)
with
torch
.
no_grad
():
pytorch_outputs
=
loaded_model
(
*
pytorch_inputs
)
if
isinstance
(
pytorch_outputs
,
torch
.
Tensor
):
pytorch_outputs
=
[
pytorch_outputs
]
else
:
if
not
isinstance
(
pytorch_outputs
,
(
list
,
tuple
)):
print
(
'return type {} unsupported yet'
.
format
(
type
(
pytorch_outputs
)))
sys
.
exit
(
1
)
for
i
in
range
(
len
(
output_names
)):
value
=
pytorch_outputs
[
i
].
numpy
()
output_file_name
=
common
.
formatted_file_name
(
mace_out_file
,
output_names
[
i
])
mace_out_value
=
load_data
(
output_file_name
)
# MACE: NHWC, pytorch: NCHW, conversion is needed
if
output_data_formats
[
i
]
==
common
.
DataFormat
.
NHWC
and
\
len
(
output_shapes
[
i
])
==
4
:
mace_out_value
=
mace_out_value
.
reshape
(
output_shapes
[
i
])
\
.
transpose
((
0
,
3
,
1
,
2
))
compare_output
(
platform
,
device_type
,
output_names
[
i
],
mace_out_value
,
value
,
validation_threshold
,
log_file
)
def
validate_caffe_model
(
platform
,
device_type
,
model_file
,
input_file
,
def
validate_caffe_model
(
platform
,
device_type
,
model_file
,
input_file
,
mace_out_file
,
weight_file
,
mace_out_file
,
weight_file
,
input_names
,
input_shapes
,
input_data_formats
,
input_names
,
input_shapes
,
input_data_formats
,
...
@@ -418,6 +460,13 @@ def validate(platform, model_file, weight_file, input_file, mace_out_file,
...
@@ -418,6 +460,13 @@ def validate(platform, model_file, weight_file, input_file, mace_out_file,
output_names
,
output_shapes
,
output_data_formats
,
output_names
,
output_shapes
,
output_data_formats
,
validation_threshold
,
input_data_types
,
validation_threshold
,
input_data_types
,
log_file
)
log_file
)
elif
platform
==
'pytorch'
:
validate_pytorch_model
(
platform
,
device_type
,
model_file
,
input_file
,
mace_out_file
,
input_names
,
input_shapes
,
input_data_formats
,
output_names
,
output_shapes
,
output_data_formats
,
validation_threshold
,
input_data_types
,
log_file
)
elif
platform
==
'caffe'
:
elif
platform
==
'caffe'
:
validate_caffe_model
(
platform
,
device_type
,
model_file
,
validate_caffe_model
(
platform
,
device_type
,
model_file
,
input_file
,
mace_out_file
,
weight_file
,
input_file
,
mace_out_file
,
weight_file
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录