Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
慢慢CG
Mace
提交
cdd9c993
Mace
项目概览
慢慢CG
/
Mace
与 Fork 源项目一致
Fork自
Xiaomi / Mace
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
cdd9c993
编写于
11月 03, 2020
作者:
Z
Zhang Zhimin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat: Support Keras tc-resnet model convert
上级
6c9f380e
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
274 addition
and
64 deletion
+274
-64
.gitlab-ci.yml
.gitlab-ci.yml
+15
-6
micro/.gitlab-ci.yml
micro/.gitlab-ci.yml
+0
-32
micro/pretrained_models/har-cnn/har-cnn-bf16.yml
micro/pretrained_models/har-cnn/har-cnn-bf16.yml
+24
-0
micro/pretrained_models/har-cnn/har-cnn.yml
micro/pretrained_models/har-cnn/har-cnn.yml
+24
-0
micro/pretrained_models/keras/mnist/mnist-int8.yml
micro/pretrained_models/keras/mnist/mnist-int8.yml
+2
-0
micro/tools/ci/cross_build.sh
micro/tools/ci/cross_build.sh
+0
-2
micro/tools/ci/host_build_and_run_tests.sh
micro/tools/ci/host_build_and_run_tests.sh
+0
-2
micro/tools/ci/model_convert.sh
micro/tools/ci/model_convert.sh
+10
-16
tools/python/transform/keras_converter.py
tools/python/transform/keras_converter.py
+143
-6
tools/python/validate.py
tools/python/validate.py
+56
-0
未找到文件。
.gitlab-ci.yml
浏览文件 @
cdd9c993
...
...
@@ -191,9 +191,18 @@ dynamic_linking_test:
only
:
-
triggers
micro-child
:
stage
:
build
trigger
:
include
:
-
'
micro/.gitlab-ci.yml'
strategy
:
depend
micro
:
stage
:
test
tags
:
-
mace-micro
image
:
mace-micro-dev
before_script
:
-
git submodule deinit -f .
-
git submodule sync
-
git submodule update --init .
script
:
-
bash micro/tools/ci/model_convert.sh
-
bash micro/tools/ci/cross_build.sh
-
bash micro/tools/ci/host_build_and_run_examples.sh
-
bash micro/tools/ci/host_build_and_run_tests.sh
-
bash micro/tools/ci/build_mbed_example.sh
\ No newline at end of file
micro/.gitlab-ci.yml
已删除
100644 → 0
浏览文件 @
6c9f380e
default
:
tags
:
-
mace-micro
image
:
mace-micro-dev
before_script
:
-
git submodule deinit -f .
-
git submodule sync
-
git submodule update --init .
stages
:
-
convert
-
build
-
test
model-convert
:
stage
:
convert
script
:
-
bash micro/tools/ci/model_convert.sh
artifacts
:
paths
:
-
mace-models
untracked
:
true
cross-build
:
stage
:
build
script
:
-
bash micro/tools/ci/cross_build.sh
-
bash micro/tools/ci/host_build_and_run_examples.sh
-
bash micro/tools/ci/host_build_and_run_tests.sh
-
bash micro/tools/ci/build_mbed_example.sh
micro/pretrained_models/har-cnn/har-cnn-bf16.yml
0 → 100644
浏览文件 @
cdd9c993
library_name
:
har-cnn
target_abis
:
[
host
]
model_graph_format
:
file
model_data_format
:
file
models
:
har_cnn
:
platform
:
tensorflow
model_file_path
:
http://cnbj1.fds.api.xiaomi.com/mace/miai-models/micro/har-cnn/har-cnn.pb
model_sha256_checksum
:
93451bdf0590842ae80e9de72a22ce3b1faee3e0d9cf7b8e2d60421e885ed6e7
subgraphs
:
-
input_tensors
:
-
conv1d/conv1d/ExpandDims
input_shapes
:
-
1,1,128,9
output_tensors
:
-
dense/BiasAdd
output_shapes
:
-
1,6
runtime
:
cpu
data_type
:
bf16_fp32
limit_opencl_kernel_time
:
0
nnlib_graph_mode
:
0
obfuscate
:
0
winograd
:
0
micro/pretrained_models/har-cnn/har-cnn.yml
0 → 100644
浏览文件 @
cdd9c993
library_name
:
har-cnn
target_abis
:
[
host
]
model_graph_format
:
file
model_data_format
:
file
models
:
har_cnn
:
platform
:
tensorflow
model_file_path
:
http://cnbj1.fds.api.xiaomi.com/mace/miai-models/micro/har-cnn/har-cnn.pb
model_sha256_checksum
:
93451bdf0590842ae80e9de72a22ce3b1faee3e0d9cf7b8e2d60421e885ed6e7
subgraphs
:
-
input_tensors
:
-
conv1d/conv1d/ExpandDims
input_shapes
:
-
1,1,128,9
output_tensors
:
-
dense/BiasAdd
output_shapes
:
-
1,6
runtime
:
cpu
data_type
:
fp32_fp32
limit_opencl_kernel_time
:
0
nnlib_graph_mode
:
0
obfuscate
:
0
winograd
:
0
micro/pretrained_models/keras/mnist/mnist-int8.yml
浏览文件 @
cdd9c993
...
...
@@ -18,6 +18,8 @@ models:
-
quant_dense_1/Softmax:0
output_shapes
:
-
1,10
validation_inputs_data
:
-
https://cnbj1.fds.api.xiaomi.com/mace/inputs/mnist4.npy
runtime
:
cpu
limit_opencl_kernel_time
:
0
nnlib_graph_mode
:
0
...
...
micro/tools/ci/cross_build.sh
浏览文件 @
cdd9c993
#! /bin/bash
git submodule update
--init
.
echo
"Builds host float32"
rm
-rf
build/micro
./micro/tools/cmake/cmake-build-host.sh
\
...
...
micro/tools/ci/host_build_and_run_tests.sh
浏览文件 @
cdd9c993
#! /bin/bash
git submodule update
--init
.
rm
-rf
build/micro
./micro/tools/cmake/cmake-build-host.sh
\
-DMACE_MICRO_ENABLE_TESTS
=
ON
\
...
...
micro/tools/ci/model_convert.sh
浏览文件 @
cdd9c993
#! /bin/bash
rm
-rf
mace-models
rm
-rf
build/micro
GIT_SSH_COMMAND
=
"ssh -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no"
git clone git@git.n.xiaomi.com:applied-machine-learning/sysml/mace-models.git
||
exit
-1
git submodule update
--init
.
||
exit
-1
CONF_FILE
=
mace-models/micro-models/har-cnn/har-cnn.yml
CONF_FILE
=
micro/pretrained_models/har-cnn/har-cnn.yml
python tools/python/convert.py
--config
=
${
CONF_FILE
}
--enable_micro
||
exit
-1
python tools/python/run_micro.py
--config
$CONF_FILE
--build
--validate
--model_name
har_cnn
||
exit
-1
python tools/python/run_micro.py
--config
$CONF_FILE
--model_name
har_cnn
--build
--benchmark
||
exit
-1
CONF_FILE
=
m
ace-models/micro-
models/har-cnn/har-cnn-bf16.yml
CONF_FILE
=
m
icro/pretrained_
models/har-cnn/har-cnn-bf16.yml
python tools/python/convert.py
--config
=
${
CONF_FILE
}
--enable_micro
||
exit
-1
python tools/python/run_micro.py
--config
$CONF_FILE
--build
--validate
--model_name
har_cnn
||
exit
-1
CONF_FILE
=
m
ace-models/micro-
models/keras/mnist/mnist.yml
CONF_FILE
=
m
icro/pretrained_
models/keras/mnist/mnist.yml
python3 tools/python/convert.py
--config
=
${
CONF_FILE
}
--enable_micro
||
exit
-1
python3 tools/python/run_micro.py
--config
$CONF_FILE
--build
--validate
--model_name
mnist
||
exit
-1
CONF_FILE
=
m
ace-models/micro-
models/keras/mnist/mnist-int8.yml
CONF_FILE
=
m
icro/pretrained_
models/keras/mnist/mnist-int8.yml
python3 tools/python/convert.py
--config
=
${
CONF_FILE
}
--enable_micro
||
exit
-1
python3 tools/python/run_micro.py
--config
$CONF_FILE
--build
--validate
--model_name
mnist_int8
||
exit
-1
CONF_FILE
=
m
ace-models/micro-
models/keras/har/har.yml
CONF_FILE
=
m
icro/pretrained_
models/keras/har/har.yml
python3 tools/python/convert.py
--config
=
${
CONF_FILE
}
--enable_micro
||
exit
-1
python3 tools/python/run_micro.py
--config
$CONF_FILE
--build
--validate
--model_name
har
||
exit
-1
CONF_FILE
=
mace-models/micro-
models/keras/har/har-int8.yml
python3 tools/python/convert.py
--config
=
${
CONF_FILE
}
--enable_micro
||
exit
-1
python3 tools/python/run_micro.py
--config
$CONF_FILE
--build
--validate
--model_name
har_int8
||
exit
-1
# CONF_FILE=micro/pretrained_
models/keras/har/har-int8.yml
#
python3 tools/python/convert.py --config=${CONF_FILE} --enable_micro || exit -1
#
python3 tools/python/run_micro.py --config $CONF_FILE --build --validate --model_name har_int8 || exit -1
CONF_FILE
=
m
ace-models/micro-
models/tensorflow/kws/kws-tc_resnet8.yml
CONF_FILE
=
m
icro/pretrained_
models/tensorflow/kws/kws-tc_resnet8.yml
python tools/python/convert.py
--config
=
${
CONF_FILE
}
--enable_micro
||
exit
-1
python tools/python/run_micro.py
--config
$CONF_FILE
--build
--validate
--model_name
kws_tc_resnet8
||
exit
-1
CONF_FILE
=
m
ace-models/micro-
models/tensorflow/kws/kws-tc_resnet8-bf16.yml
CONF_FILE
=
m
icro/pretrained_
models/tensorflow/kws/kws-tc_resnet8-bf16.yml
python tools/python/convert.py
--config
=
${
CONF_FILE
}
--enable_micro
||
exit
-1
python tools/python/run_micro.py
--config
$CONF_FILE
--build
--validate
--model_name
kws_tc_resnet8_bf16
||
exit
-1
rm
-rf
mace-models
tools/python/transform/keras_converter.py
浏览文件 @
cdd9c993
...
...
@@ -20,6 +20,8 @@ from transform.base_converter import ReduceType
from
transform.base_converter
import
RoundMode
from
tensorflow
import
keras
from
tensorflow.python.keras.layers
import
convolutional
from
tensorflow.python.keras
import
activations
from
quantize
import
quantize_util
from
utils.util
import
mace_check
...
...
@@ -32,6 +34,8 @@ from tensorflow_model_optimization.python.core.\
from
tensorflow_model_optimization
.
python
.
core
.
\
quantization
.
keras
.
quantize_annotate
import
QuantizeAnnotate
import
numpy
as
np
padding_mode
=
{
"valid"
:
PaddingMode
.
VALID
,
"same"
:
PaddingMode
.
SAME
...
...
@@ -74,7 +78,7 @@ def get_output(keras_op):
return
keras_op
.
output
activation_type
=
{
activation_type
s_dict
=
{
"relu"
:
ActivationType
.
RELU
,
# 'relu6': ActivationType.RELUX,
# 'PReLU': ActivationType.PRELU,
...
...
@@ -89,6 +93,7 @@ class KerasConverter(base_converter.ConverterInterface):
def
__init__
(
self
,
option
,
src_model_file
):
self
.
_op_converters
=
{
keras
.
layers
.
InputLayer
:
self
.
convert_input_layer
,
keras
.
layers
.
Flatten
:
self
.
convert_flatten
,
keras
.
layers
.
Dense
:
self
.
convert_dense
,
keras
.
layers
.
Conv2D
:
self
.
convert_conv2d
,
...
...
@@ -96,6 +101,11 @@ class KerasConverter(base_converter.ConverterInterface):
keras
.
layers
.
Dropout
:
self
.
convert_dropout
,
keras
.
layers
.
DepthwiseConv2D
:
self
.
convert_depthwise_conv2d
,
keras
.
layers
.
Softmax
:
self
.
convert_softmax
,
keras
.
layers
.
BatchNormalization
:
self
.
convert_batch_normalization
,
keras
.
layers
.
Activation
:
self
.
convert_activation
,
keras
.
layers
.
GlobalAveragePooling2D
:
self
.
convert_global_average_pooling2d
,
keras
.
layers
.
Add
:
self
.
convert_add
,
QuantizeLayer
:
self
.
convert_quantize_layer
,
QuantizeWrapper
:
self
.
convert_quantize_wrapper
,
}
...
...
@@ -106,7 +116,8 @@ class KerasConverter(base_converter.ConverterInterface):
ConverterUtil
.
add_data_format_arg
(
self
.
_mace_net_def
,
DataFormat
.
NHWC
)
with
tfmot
.
quantization
.
keras
.
quantize_scope
():
self
.
_keras_model
=
keras
.
models
.
load_model
(
src_model_file
)
self
.
_keras_model
=
keras
.
models
.
load_model
(
src_model_file
,
compile
=
False
)
def
run
(
self
):
for
op
in
self
.
_keras_model
.
layers
:
...
...
@@ -141,10 +152,24 @@ class KerasConverter(base_converter.ConverterInterface):
framework_type_arg
.
i
=
FrameworkType
.
KERAS
.
value
ConverterUtil
.
add_data_format_arg
(
op
,
DataFormat
.
NHWC
)
op
.
input
.
append
(
get_input
(
keras_op
).
name
)
op
.
output
.
append
(
get_output
(
keras_op
).
name
)
input
=
get_input
(
keras_op
)
if
isinstance
(
input
,
list
):
for
e
in
input
:
op
.
input
.
append
(
e
.
name
)
else
:
op
.
input
.
append
(
input
.
name
)
output
=
get_output
(
keras_op
)
mace_check
(
not
isinstance
(
output
,
list
),
"only support one output"
)
op
.
output
.
append
(
output
.
name
)
output_shape
=
op
.
output_shape
.
add
()
output_shape
.
dims
.
extend
(
keras_shape2list
(
get_output
(
keras_op
).
shape
))
output_shape
.
dims
.
extend
(
keras_shape2list
(
output
.
shape
))
return
op
def
convert_input_layer
(
self
,
keras_op
):
op
=
self
.
convert_general_op_with_input_output
(
keras_op
)
op
.
type
=
MaceOp
.
Identity
.
name
return
op
...
...
@@ -268,6 +293,100 @@ class KerasConverter(base_converter.ConverterInterface):
return
op
def
convert_batch_normalization
(
self
,
keras_op
):
op
=
self
.
convert_general_op_with_input_output
(
keras_op
)
op
.
type
=
MaceOp
.
BatchNorm
.
name
gamma
=
keras_op
.
gamma
.
numpy
()
beta
=
keras_op
.
beta
.
numpy
()
mean
=
keras_op
.
moving_mean
.
numpy
()
variance
=
keras_op
.
moving_variance
.
numpy
()
epsilon
=
keras_op
.
epsilon
scale
=
(
1.0
/
np
.
sqrt
(
variance
+
epsilon
))
*
gamma
offset
=
(
-
mean
*
scale
)
+
beta
scale_name
=
keras_op
.
name
+
'/scale:0'
offset_name
=
keras_op
.
name
+
'/offset:0'
self
.
add_numpy_tensor
(
scale_name
,
scale
)
self
.
add_numpy_tensor
(
offset_name
,
offset
)
op
.
input
.
extend
([
scale_name
,
offset_name
])
return
op
def
convert_global_average_pooling2d
(
self
,
keras_op
):
op
=
self
.
convert_general_op_with_input_output
(
keras_op
)
op
.
type
=
MaceOp
.
Reduce
.
name
reduce_type_arg
=
op
.
arg
.
add
()
reduce_type_arg
.
name
=
MaceKeyword
.
mace_reduce_type_str
reduce_type_arg
.
i
=
ReduceType
.
MEAN
.
value
axis_arg
=
op
.
arg
.
add
()
axis_arg
.
name
=
MaceKeyword
.
mace_axis_str
axis_arg
.
ints
.
extend
([
1
,
2
])
keep_dims_arg
=
op
.
arg
.
add
()
keep_dims_arg
.
name
=
MaceKeyword
.
mace_keepdims_str
keep_dims_arg
.
i
=
1
origin_output_shape
=
copy
.
deepcopy
(
op
.
output_shape
[
0
].
dims
)
op
.
output_shape
[
0
].
dims
.
insert
(
1
,
1
)
op
.
output_shape
[
0
].
dims
.
insert
(
1
,
1
)
output_name
=
op
.
output
[
0
]
del
op
.
output
[:]
output_name_mid
=
output_name
+
"_mid_reshape"
op
.
output
.
append
(
output_name_mid
)
op_reshape
=
self
.
_mace_net_def
.
op
.
add
()
op_reshape
.
name
=
keras_op
.
name
+
"_reshape"
op_reshape
.
type
=
MaceOp
.
Reshape
.
name
op_reshape
.
input
.
append
(
output_name_mid
)
op_reshape
.
output
.
append
(
output_name
)
output_shape
=
op_reshape
.
output_shape
.
add
()
output_shape
.
dims
.
extend
(
origin_output_shape
)
t_shape
=
list
(
origin_output_shape
)
shape_tensor_name
=
op_reshape
.
name
+
"_dest_shape"
self
.
add_tensor
(
shape_tensor_name
,
[
len
(
t_shape
)],
mace_pb2
.
DT_INT32
,
t_shape
)
op_reshape
.
input
.
append
(
shape_tensor_name
)
data_type_arg
=
op_reshape
.
arg
.
add
()
data_type_arg
.
name
=
"T"
data_type_arg
.
i
=
dtype2mtype
(
keras_op
.
dtype
)
framework_type_arg
=
op_reshape
.
arg
.
add
()
framework_type_arg
.
name
=
MaceKeyword
.
mace_framework_type_str
framework_type_arg
.
i
=
FrameworkType
.
KERAS
.
value
ConverterUtil
.
add_data_format_arg
(
op_reshape
,
DataFormat
.
NHWC
)
return
op_reshape
def
convert_activation
(
self
,
keras_op
):
op
=
self
.
convert_general_op_with_input_output
(
keras_op
)
activation
=
keras_op
.
activation
if
activation
==
activations
.
linear
:
op
.
type
=
MaceOp
.
Identity
.
name
elif
activation
is
activations
.
relu
:
op
.
type
=
MaceOp
.
Activation
.
name
type_arg
=
op
.
arg
.
add
()
type_arg
.
name
=
MaceKeyword
.
mace_activation_type_str
type_arg
.
s
=
six
.
b
(
"RELU"
)
elif
activation
==
activations
.
softmax
:
op
.
type
=
MaceOp
.
Softmax
.
name
else
:
mace_check
(
False
,
"Unsupported activation"
)
return
op
def
convert_add
(
self
,
keras_op
):
op
=
self
.
convert_general_op_with_input_output
(
keras_op
)
op
.
type
=
MaceOp
.
Eltwise
.
name
type_arg
=
op
.
arg
.
add
()
type_arg
.
name
=
MaceKeyword
.
mace_element_type_str
type_arg
.
i
=
EltwiseType
.
SUM
.
value
return
op
def
convert_quantize_layer
(
self
,
keras_op
):
op
=
self
.
_mace_net_def
.
op
.
add
()
op
.
name
=
keras_op
.
name
...
...
@@ -328,6 +447,24 @@ class KerasConverter(base_converter.ConverterInterface):
tensor
.
float_data
.
extend
(
keras_tensor
.
numpy
().
flat
)
return
tensor
def
add_numpy_tensor
(
self
,
name
,
np_tensor
):
tensor
=
self
.
_mace_net_def
.
tensors
.
add
()
tensor
.
name
=
name
tensor
.
dims
.
extend
(
np_tensor
.
shape
)
tensor
.
data_type
=
dtype2mtype
(
np_tensor
.
dtype
)
tensor
.
float_data
.
extend
(
np_tensor
.
flat
)
return
tensor
def
add_tensor
(
self
,
name
,
shape
,
data_type
,
value
):
tensor
=
self
.
_mace_net_def
.
tensors
.
add
()
tensor
.
name
=
name
tensor
.
dims
.
extend
(
list
(
shape
))
tensor
.
data_type
=
data_type
if
data_type
==
mace_pb2
.
DT_INT32
:
tensor
.
int32_data
.
extend
(
value
)
else
:
tensor
.
float_data
.
extend
(
value
)
def
split_activation_op
(
self
,
keras_op
,
op
):
activation
=
keras_op
.
get_config
()[
"activation"
]
if
"class_name"
in
activation
:
...
...
@@ -358,7 +495,7 @@ class KerasConverter(base_converter.ConverterInterface):
activation_op
.
type
=
MaceOp
.
Activation
.
name
type_arg
=
activation_op
.
arg
.
add
()
type_arg
.
name
=
MaceKeyword
.
mace_activation_type_str
type_arg
.
s
=
six
.
b
(
activation_type
[
activation
].
name
)
type_arg
.
s
=
six
.
b
(
activation_type
s_dict
[
activation
].
name
)
activation_op
.
input
.
append
(
activation_tmp_name
)
activation_op
.
output
.
append
(
get_output
(
keras_op
).
name
)
...
...
tools/python/validate.py
浏览文件 @
cdd9c993
...
...
@@ -408,6 +408,54 @@ def validate_megengine_model(model_file, input_file,
mge_output_value
,
validation_threshold
,
log_file
)
def
validate_keras_model
(
model_file
,
input_file
,
mace_out_file
,
input_names
,
input_shapes
,
input_data_formats
,
output_names
,
output_shapes
,
output_data_formats
,
validation_threshold
,
input_data_types
,
log_file
):
from
tensorflow
import
keras
import
tensorflow_model_optimization
as
tfmot
if
not
os
.
path
.
isfile
(
model_file
):
util
.
MaceLogger
.
error
(
VALIDATION_MODULE
,
"Input model file '"
+
model_file
+
"' does not exist!"
)
with
tfmot
.
quantization
.
keras
.
quantize_scope
():
keras_model
=
keras
.
models
.
load_model
(
model_file
,
compile
=
False
)
input
=
[]
for
i
in
range
(
len
(
input_names
)):
input_value
=
load_data
(
util
.
formatted_file_name
(
input_file
,
input_names
[
i
]),
input_data_types
[
i
])
input_value
=
input_value
.
reshape
(
input_shapes
[
i
])
if
input_data_formats
[
i
]
==
DataFormat
.
NCHW
and
\
len
(
input_shapes
[
i
])
==
4
:
input_value
=
input_value
.
transpose
((
0
,
2
,
3
,
1
))
elif
input_data_formats
[
i
]
==
DataFormat
.
OIHW
and
\
len
(
input_shapes
[
i
])
==
4
:
# OIHW -> HWIO
input_value
=
input_value
.
transpose
((
2
,
3
,
1
,
0
))
input
.
append
(
input_value
)
output_values
=
keras_model
.
predict
(
input
)
for
i
in
range
(
len
(
output_names
)):
output_file_name
=
util
.
formatted_file_name
(
mace_out_file
,
output_names
[
i
])
mace_out_value
=
load_data
(
output_file_name
,
get_data_type_by_value
(
output_values
[
i
]))
if
output_data_formats
[
i
]
==
DataFormat
.
NCHW
and
\
len
(
output_shapes
[
i
])
==
4
:
mace_out_value
=
mace_out_value
.
\
reshape
(
output_shapes
[
i
]).
transpose
((
0
,
2
,
3
,
1
))
compare_output
(
output_names
[
i
],
mace_out_value
,
output_values
[
i
],
validation_threshold
,
log_file
)
def
validate
(
platform
,
model_file
,
weight_file
,
input_file
,
mace_out_file
,
input_shape
,
output_shape
,
input_data_format
,
output_data_format
,
input_node
,
output_node
,
...
...
@@ -458,3 +506,11 @@ def validate(platform, model_file, weight_file, input_file, mace_out_file,
output_data_format
,
validation_threshold
,
input_data_type
,
log_file
)
elif
platform
==
Platform
.
KERAS
:
validate_keras_model
(
model_file
,
input_file
,
mace_out_file
,
input_node
,
input_shape
,
input_data_format
,
output_node
,
output_shape
,
output_data_format
,
validation_threshold
,
input_data_type
,
log_file
)
else
:
mace_check
(
False
,
"Unsupported platform"
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录