Commit 65aa9370 authored by 卢旭辉

Merge branch 'keras_resnet' into 'master'

feat: Support Keras tc-resnet model convert

See merge request applied-machine-learning/sysml/mace!1317
......@@ -191,9 +191,18 @@ dynamic_linking_test:
only:
- triggers
micro-child:
stage: build
trigger:
include:
- 'micro/.gitlab-ci.yml'
strategy: depend
micro:
stage: test
tags:
- mace-micro
image: mace-micro-dev
before_script:
- git submodule deinit -f .
- git submodule sync
- git submodule update --init .
script:
- bash micro/tools/ci/model_convert.sh
- bash micro/tools/ci/cross_build.sh
- bash micro/tools/ci/host_build_and_run_examples.sh
- bash micro/tools/ci/host_build_and_run_tests.sh
- bash micro/tools/ci/build_mbed_example.sh
\ No newline at end of file
default:
tags:
- mace-micro
image: mace-micro-dev
before_script:
- git submodule deinit -f .
- git submodule sync
- git submodule update --init .
stages:
- convert
- build
- test
model-convert:
stage: convert
script:
- bash micro/tools/ci/model_convert.sh
artifacts:
paths:
- mace-models
untracked: true
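# The converted models (mace-models) are handed to the later build and test stages as job artifacts.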
cross-build:
stage: build
script:
- bash micro/tools/ci/cross_build.sh
- bash micro/tools/ci/host_build_and_run_examples.sh
- bash micro/tools/ci/host_build_and_run_tests.sh
- bash micro/tools/ci/build_mbed_example.sh
library_name: har-cnn
target_abis: [host]
model_graph_format: file
model_data_format: file
models:
har_cnn:
platform: tensorflow
model_file_path: http://cnbj1.fds.api.xiaomi.com/mace/miai-models/micro/har-cnn/har-cnn.pb
model_sha256_checksum: 93451bdf0590842ae80e9de72a22ce3b1faee3e0d9cf7b8e2d60421e885ed6e7
subgraphs:
- input_tensors:
- conv1d/conv1d/ExpandDims
input_shapes:
- 1,1,128,9
output_tensors:
- dense/BiasAdd
output_shapes:
- 1,6
runtime: cpu
data_type: bf16_fp32
limit_opencl_kernel_time: 0
nnlib_graph_mode: 0
obfuscate: 0
winograd: 0
library_name: har-cnn
target_abis: [host]
model_graph_format: file
model_data_format: file
models:
har_cnn:
platform: tensorflow
model_file_path: http://cnbj1.fds.api.xiaomi.com/mace/miai-models/micro/har-cnn/har-cnn.pb
model_sha256_checksum: 93451bdf0590842ae80e9de72a22ce3b1faee3e0d9cf7b8e2d60421e885ed6e7
subgraphs:
- input_tensors:
- conv1d/conv1d/ExpandDims
input_shapes:
- 1,1,128,9
output_tensors:
- dense/BiasAdd
output_shapes:
- 1,6
runtime: cpu
data_type: fp32_fp32
limit_opencl_kernel_time: 0
nnlib_graph_mode: 0
obfuscate: 0
winograd: 0
......@@ -18,6 +18,8 @@ models:
- quant_dense_1/Softmax:0
output_shapes:
- 1,10
validation_inputs_data:
- https://cnbj1.fds.api.xiaomi.com/mace/inputs/mnist4.npy
runtime: cpu
limit_opencl_kernel_time: 0
nnlib_graph_mode: 0
......
#! /bin/bash
git submodule update --init .
echo "Builds host float32"
rm -rf build/micro
./micro/tools/cmake/cmake-build-host.sh \
......
#! /bin/bash
git submodule update --init .
rm -rf build/micro
./micro/tools/cmake/cmake-build-host.sh \
-DMACE_MICRO_ENABLE_TESTS=ON \
......
#! /bin/bash
rm -rf mace-models
rm -rf build/micro
GIT_SSH_COMMAND="ssh -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no" git clone git@git.n.xiaomi.com:applied-machine-learning/sysml/mace-models.git || exit -1
git submodule update --init . || exit -1
CONF_FILE=mace-models/micro-models/har-cnn/har-cnn.yml
CONF_FILE=micro/pretrained_models/har-cnn/har-cnn.yml
python tools/python/convert.py --config=${CONF_FILE} --enable_micro || exit -1
python tools/python/run_micro.py --config $CONF_FILE --build --validate --model_name har_cnn || exit -1
python tools/python/run_micro.py --config $CONF_FILE --model_name har_cnn --build --benchmark || exit -1
CONF_FILE=mace-models/micro-models/har-cnn/har-cnn-bf16.yml
CONF_FILE=micro/pretrained_models/har-cnn/har-cnn-bf16.yml
python tools/python/convert.py --config=${CONF_FILE} --enable_micro || exit -1
python tools/python/run_micro.py --config $CONF_FILE --build --validate --model_name har_cnn || exit -1
CONF_FILE=mace-models/micro-models/keras/mnist/mnist.yml
CONF_FILE=micro/pretrained_models/keras/mnist/mnist.yml
python3 tools/python/convert.py --config=${CONF_FILE} --enable_micro || exit -1
python3 tools/python/run_micro.py --config $CONF_FILE --build --validate --model_name mnist || exit -1
CONF_FILE=mace-models/micro-models/keras/mnist/mnist-int8.yml
CONF_FILE=micro/pretrained_models/keras/mnist/mnist-int8.yml
python3 tools/python/convert.py --config=${CONF_FILE} --enable_micro || exit -1
python3 tools/python/run_micro.py --config $CONF_FILE --build --validate --model_name mnist_int8 || exit -1
CONF_FILE=mace-models/micro-models/keras/har/har.yml
CONF_FILE=micro/pretrained_models/keras/har/har.yml
python3 tools/python/convert.py --config=${CONF_FILE} --enable_micro || exit -1
python3 tools/python/run_micro.py --config $CONF_FILE --build --validate --model_name har || exit -1
CONF_FILE=mace-models/micro-models/keras/har/har-int8.yml
python3 tools/python/convert.py --config=${CONF_FILE} --enable_micro || exit -1
python3 tools/python/run_micro.py --config $CONF_FILE --build --validate --model_name har_int8 || exit -1
# CONF_FILE=micro/pretrained_models/keras/har/har-int8.yml
# python3 tools/python/convert.py --config=${CONF_FILE} --enable_micro || exit -1
# python3 tools/python/run_micro.py --config $CONF_FILE --build --validate --model_name har_int8 || exit -1
CONF_FILE=mace-models/micro-models/tensorflow/kws/kws-tc_resnet8.yml
CONF_FILE=micro/pretrained_models/tensorflow/kws/kws-tc_resnet8.yml
python tools/python/convert.py --config=${CONF_FILE} --enable_micro || exit -1
python tools/python/run_micro.py --config $CONF_FILE --build --validate --model_name kws_tc_resnet8 || exit -1
CONF_FILE=mace-models/micro-models/tensorflow/kws/kws-tc_resnet8-bf16.yml
CONF_FILE=micro/pretrained_models/tensorflow/kws/kws-tc_resnet8-bf16.yml
python tools/python/convert.py --config=${CONF_FILE} --enable_micro || exit -1
python tools/python/run_micro.py --config $CONF_FILE --build --validate --model_name kws_tc_resnet8_bf16 || exit -1
rm -rf mace-models
......@@ -20,6 +20,8 @@ from transform.base_converter import ReduceType
from transform.base_converter import RoundMode
from tensorflow import keras
from tensorflow.python.keras.layers import convolutional
from tensorflow.python.keras import activations
from quantize import quantize_util
from utils.util import mace_check
......@@ -32,6 +34,8 @@ from tensorflow_model_optimization.python.core.\
from tensorflow_model_optimization.python.core.\
quantization.keras.quantize_annotate import QuantizeAnnotate
import numpy as np
padding_mode = {
"valid": PaddingMode.VALID,
"same": PaddingMode.SAME
......@@ -74,7 +78,7 @@ def get_output(keras_op):
return keras_op.output
activation_type = {
activation_types_dict = {
"relu": ActivationType.RELU,
# 'relu6': ActivationType.RELUX,
# 'PReLU': ActivationType.PRELU,
......@@ -89,6 +93,7 @@ class KerasConverter(base_converter.ConverterInterface):
def __init__(self, option, src_model_file):
self._op_converters = {
keras.layers.InputLayer: self.convert_input_layer,
keras.layers.Flatten: self.convert_flatten,
keras.layers.Dense: self.convert_dense,
keras.layers.Conv2D: self.convert_conv2d,
......@@ -96,6 +101,11 @@ class KerasConverter(base_converter.ConverterInterface):
keras.layers.Dropout: self.convert_dropout,
keras.layers.DepthwiseConv2D: self.convert_depthwise_conv2d,
keras.layers.Softmax: self.convert_softmax,
keras.layers.BatchNormalization: self.convert_batch_normalization,
keras.layers.Activation: self.convert_activation,
keras.layers.GlobalAveragePooling2D:
self.convert_global_average_pooling2d,
keras.layers.Add: self.convert_add,
QuantizeLayer: self.convert_quantize_layer,
QuantizeWrapper: self.convert_quantize_wrapper,
}
......@@ -106,7 +116,8 @@ class KerasConverter(base_converter.ConverterInterface):
ConverterUtil.add_data_format_arg(self._mace_net_def, DataFormat.NHWC)
with tfmot.quantization.keras.quantize_scope():
self._keras_model = keras.models.load_model(src_model_file)
self._keras_model = keras.models.load_model(src_model_file,
compile=False)
def run(self):
for op in self._keras_model.layers:
......@@ -141,10 +152,24 @@ class KerasConverter(base_converter.ConverterInterface):
framework_type_arg.i = FrameworkType.KERAS.value
ConverterUtil.add_data_format_arg(op, DataFormat.NHWC)
op.input.append(get_input(keras_op).name)
op.output.append(get_output(keras_op).name)
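# Layers such as keras.layers.Add expose a list of input tensors; append every input name.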
input = get_input(keras_op)
if isinstance(input, list):
for e in input:
op.input.append(e.name)
else:
op.input.append(input.name)
output = get_output(keras_op)
mace_check(not isinstance(output, list), "only support one output")
op.output.append(output.name)
output_shape = op.output_shape.add()
output_shape.dims.extend(keras_shape2list(get_output(keras_op).shape))
output_shape.dims.extend(keras_shape2list(output.shape))
return op
def convert_input_layer(self, keras_op):
op = self.convert_general_op_with_input_output(keras_op)
op.type = MaceOp.Identity.name
return op
......@@ -268,6 +293,100 @@ class KerasConverter(base_converter.ConverterInterface):
return op
def convert_batch_normalization(self, keras_op):
op = self.convert_general_op_with_input_output(keras_op)
op.type = MaceOp.BatchNorm.name
gamma = keras_op.gamma.numpy()
beta = keras_op.beta.numpy()
mean = keras_op.moving_mean.numpy()
variance = keras_op.moving_variance.numpy()
epsilon = keras_op.epsilon
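# Fold BatchNormalization into an affine transform y = scale * x + offset,
# where scale = gamma / sqrt(variance + epsilon) and offset = beta - mean * scale
# (see the standalone NumPy check after this file's diff).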
scale = (1.0 / np.sqrt(variance + epsilon)) * gamma
offset = (-mean * scale) + beta
scale_name = keras_op.name + '/scale:0'
offset_name = keras_op.name + '/offset:0'
self.add_numpy_tensor(scale_name, scale)
self.add_numpy_tensor(offset_name, offset)
op.input.extend([scale_name, offset_name])
return op
def convert_global_average_pooling2d(self, keras_op):
op = self.convert_general_op_with_input_output(keras_op)
op.type = MaceOp.Reduce.name
reduce_type_arg = op.arg.add()
reduce_type_arg.name = MaceKeyword.mace_reduce_type_str
reduce_type_arg.i = ReduceType.MEAN.value
axis_arg = op.arg.add()
axis_arg.name = MaceKeyword.mace_axis_str
axis_arg.ints.extend([1, 2])
keep_dims_arg = op.arg.add()
keep_dims_arg.name = MaceKeyword.mace_keepdims_str
keep_dims_arg.i = 1
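# Keras GlobalAveragePooling2D outputs [N, C], while Reduce with keep_dims
# produces [N, 1, 1, C]; write the pooled result to a temporary tensor and
# append a Reshape back to the layer's original output shape.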
origin_output_shape = copy.deepcopy(op.output_shape[0].dims)
op.output_shape[0].dims.insert(1, 1)
op.output_shape[0].dims.insert(1, 1)
output_name = op.output[0]
del op.output[:]
output_name_mid = output_name + "_mid_reshape"
op.output.append(output_name_mid)
op_reshape = self._mace_net_def.op.add()
op_reshape.name = keras_op.name + "_reshape"
op_reshape.type = MaceOp.Reshape.name
op_reshape.input.append(output_name_mid)
op_reshape.output.append(output_name)
output_shape = op_reshape.output_shape.add()
output_shape.dims.extend(origin_output_shape)
t_shape = list(origin_output_shape)
shape_tensor_name = op_reshape.name + "_dest_shape"
self.add_tensor(
shape_tensor_name, [len(t_shape)], mace_pb2.DT_INT32, t_shape
)
op_reshape.input.append(shape_tensor_name)
data_type_arg = op_reshape.arg.add()
data_type_arg.name = "T"
data_type_arg.i = dtype2mtype(keras_op.dtype)
framework_type_arg = op_reshape.arg.add()
framework_type_arg.name = MaceKeyword.mace_framework_type_str
framework_type_arg.i = FrameworkType.KERAS.value
ConverterUtil.add_data_format_arg(op_reshape, DataFormat.NHWC)
return op_reshape
def convert_activation(self, keras_op):
op = self.convert_general_op_with_input_output(keras_op)
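# keras_op.activation is the activation function object; map linear, relu and
# softmax onto Identity, Activation(RELU) and Softmax respectively.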
activation = keras_op.activation
if activation == activations.linear:
op.type = MaceOp.Identity.name
elif activation == activations.relu:
op.type = MaceOp.Activation.name
type_arg = op.arg.add()
type_arg.name = MaceKeyword.mace_activation_type_str
type_arg.s = six.b("RELU")
elif activation == activations.softmax:
op.type = MaceOp.Softmax.name
else:
mace_check(False, "Unsupported activation")
return op
def convert_add(self, keras_op):
op = self.convert_general_op_with_input_output(keras_op)
op.type = MaceOp.Eltwise.name
type_arg = op.arg.add()
type_arg.name = MaceKeyword.mace_element_type_str
type_arg.i = EltwiseType.SUM.value
return op
def convert_quantize_layer(self, keras_op):
op = self._mace_net_def.op.add()
op.name = keras_op.name
......@@ -328,6 +447,24 @@ class KerasConverter(base_converter.ConverterInterface):
tensor.float_data.extend(keras_tensor.numpy().flat)
return tensor
def add_numpy_tensor(self, name, np_tensor):
tensor = self._mace_net_def.tensors.add()
tensor.name = name
tensor.dims.extend(np_tensor.shape)
tensor.data_type = dtype2mtype(np_tensor.dtype)
tensor.float_data.extend(np_tensor.flat)
return tensor
def add_tensor(self, name, shape, data_type, value):
tensor = self._mace_net_def.tensors.add()
tensor.name = name
tensor.dims.extend(list(shape))
tensor.data_type = data_type
if data_type == mace_pb2.DT_INT32:
tensor.int32_data.extend(value)
else:
tensor.float_data.extend(value)
def split_activation_op(self, keras_op, op):
activation = keras_op.get_config()["activation"]
if "class_name" in activation:
......@@ -358,7 +495,7 @@ class KerasConverter(base_converter.ConverterInterface):
activation_op.type = MaceOp.Activation.name
type_arg = activation_op.arg.add()
type_arg.name = MaceKeyword.mace_activation_type_str
type_arg.s = six.b(activation_type[activation].name)
type_arg.s = six.b(activation_types_dict[activation].name)
activation_op.input.append(activation_tmp_name)
activation_op.output.append(get_output(keras_op).name)
......
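As a side note (illustrative only, not part of this merge request), the scale/offset folding performed by convert_batch_normalization can be sanity-checked with plain NumPy; the shapes and random values below are arbitrary:

import numpy as np

# Inference-time BatchNormalization: y = gamma * (x - mean) / sqrt(variance + epsilon) + beta
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4, 4, 3)).astype(np.float32)  # NHWC activations
gamma, beta = rng.standard_normal(3), rng.standard_normal(3)
mean, variance = rng.standard_normal(3), rng.random(3)
epsilon = 1e-3
bn_out = gamma * (x - mean) / np.sqrt(variance + epsilon) + beta

# Folded form, matching the converter: scale * x + offset
scale = (1.0 / np.sqrt(variance + epsilon)) * gamma
offset = (-mean * scale) + beta
folded_out = scale * x + offset

assert np.allclose(bn_out, folded_out, atol=1e-5)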
......@@ -408,6 +408,54 @@ def validate_megengine_model(model_file, input_file,
mge_output_value, validation_threshold, log_file)
def validate_keras_model(model_file,
input_file, mace_out_file,
input_names, input_shapes, input_data_formats,
output_names, output_shapes, output_data_formats,
validation_threshold, input_data_types, log_file):
from tensorflow import keras
import tensorflow_model_optimization as tfmot
if not os.path.isfile(model_file):
util.MaceLogger.error(
VALIDATION_MODULE,
"Input model file '" + model_file + "' does not exist!")
with tfmot.quantization.keras.quantize_scope():
keras_model = keras.models.load_model(model_file, compile=False)
input = []
for i in range(len(input_names)):
input_value = load_data(
util.formatted_file_name(input_file, input_names[i]),
input_data_types[i])
input_value = input_value.reshape(input_shapes[i])
if input_data_formats[i] == DataFormat.NCHW and \
len(input_shapes[i]) == 4:
input_value = input_value.transpose((0, 2, 3, 1))
elif input_data_formats[i] == DataFormat.OIHW and \
len(input_shapes[i]) == 4:
# OIHW -> HWIO
input_value = input_value.transpose((2, 3, 1, 0))
input.append(input_value)
output_values = keras_model.predict(input)
for i in range(len(output_names)):
output_file_name = util.formatted_file_name(
mace_out_file, output_names[i])
mace_out_value = load_data(
output_file_name,
get_data_type_by_value(output_values[i]))
if output_data_formats[i] == DataFormat.NCHW and \
len(output_shapes[i]) == 4:
mace_out_value = mace_out_value. \
reshape(output_shapes[i]).transpose((0, 2, 3, 1))
compare_output(output_names[i],
mace_out_value, output_values[i],
validation_threshold, log_file)
def validate(platform, model_file, weight_file, input_file, mace_out_file,
input_shape, output_shape, input_data_format,
output_data_format, input_node, output_node,
......@@ -458,3 +506,11 @@ def validate(platform, model_file, weight_file, input_file, mace_out_file,
output_data_format,
validation_threshold,
input_data_type, log_file)
elif platform == Platform.KERAS:
validate_keras_model(model_file, input_file, mace_out_file,
input_node, input_shape, input_data_format,
output_node, output_shape, output_data_format,
validation_threshold, input_data_type,
log_file)
else:
mace_check(False, "Unsupported platform")