Commit 3d195176 authored by 李寅

Init quantize project:

1. Quantize weights
2. Add quantize-dequantize nodes
Parent b23889d5
......@@ -13,6 +13,8 @@
// limitations under the License.
#include <utility>
#include <algorithm>
#include <limits>
#include "mace/core/macros.h"
#include "mace/core/net.h"
......@@ -125,6 +127,24 @@ MaceStatus SerialNet::Run(RunMetadata *run_metadata) {
VLOG(3) << "Operator " << op->debug_def().name()
<< " has shape: " << MakeString(op->Output(0)->shape());
if (EnvEnabled("MACE_LOG_TENSOR_RANGE") && device_type_ == CPU) {
for (int i = 0; i < op->OutputSize(); ++i) {
int data_type = op->GetOptionalArg("T", static_cast<int>(DT_FLOAT));
if (data_type == static_cast<int>(DT_FLOAT)) {
float max_v = std::numeric_limits<float>::lowest();
float min_v = std::numeric_limits<float>::max();
Tensor::MappingGuard guard(op->Output(i));
const float *output_data = op->Output(i)->data<float>();
for (index_t j = 0; j < op->Output(i)->size(); ++j) {
max_v = std::max(max_v, output_data[j]);
min_v = std::min(min_v, output_data[j]);
}
LOG(INFO) << "Tensor range @@" << op->debug_def().output(i)
<< "@@" << min_v << "," << max_v;
}
}
}
}
return MACE_SUCCESS;
......
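The logging above emits one line per float output tensor in the form "Tensor range @@<name>@@<min>,<max>", which the statistics tool added later in this commit consumes. A minimal Python sketch of the same collection step (hypothetical helper, for illustration only):

import numpy as np

def log_tensor_range(tensor_name, output_data):
    # Mirrors the C++ loop above: scan a float tensor and report the
    # observed range in the "Tensor range @@name@@min,max" log format.
    min_v = float(np.min(output_data))
    max_v = float(np.max(output_data))
    print("Tensor range @@%s@@%f,%f" % (tensor_name, min_v, max_v))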
......@@ -161,6 +161,8 @@ MaceStatus Workspace::LoadModelTensor(const NetDef &net_def,
const_tensor.data_type()));
tensor->Reshape(dims);
tensor->SetScale(const_tensor.scale());
tensor->SetZeroPoint(const_tensor.zero_point());
tensor_map_[const_tensor.name()] = std::move(tensor);
}
}
......@@ -170,37 +172,48 @@ MaceStatus Workspace::LoadModelTensor(const NetDef &net_def,
MaceStatus status = CreateOutputTensorBuffer(net_def, type);
if (status != MaceStatus::MACE_SUCCESS) return status;
}
if (type == DeviceType::CPU && net_def.has_quantize_info()) {
for (const auto
&activation_info: net_def.quantize_info().activation_info()) {
MACE_CHECK(HasTensor(activation_info.tensor_name()),
"Quantize info exist for non-existed tensor",
activation_info.tensor_name());
Tensor *tensor = GetTensor(activation_info.tensor_name());
tensor->SetScale(activation_info.scale());
tensor->SetZeroPoint(activation_info.zero_point());
}
}
return MaceStatus::MACE_SUCCESS;
}
MaceStatus Workspace::CreateOutputTensorBuffer(const NetDef &net_def,
DeviceType device_type) {
if (!net_def.has_mem_arena() || net_def.mem_arena().mem_block_size() == 0) {
return MaceStatus::MACE_SUCCESS;
}
DataType dtype = DataType::DT_INVALID;
// We use the data type of the first op with mem id,
// as CPU&GPU have consistent data type for each layer for now.
// As DSP may have different data output type for each op,
// we stick to the same concept.
for (auto &op : net_def.op()) {
// TODO(liuqi): refactor to add device_type to OperatorDef
const int op_device =
ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
op, "device", static_cast<int>(device_type));
if (op_device == device_type && !op.mem_id().empty()) {
const DataType op_dtype = static_cast<DataType>(
if (net_def.mem_arena().mem_block_size() > 0) {
// We use the data type of the first op with mem id,
// as CPU&GPU have consistent data type for each layer for now.
// As DSP may have different data output type for each op,
// we stick to the same concept.
for (auto &op : net_def.op()) {
// TODO(liuqi): refactor to add device_type to OperatorDef
const int op_device =
ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
op, "T", static_cast<int>(DT_FLOAT)));
if (op_dtype != DataType::DT_INVALID) {
dtype = op_dtype;
// find first valid data type, break
break;
op, "device", static_cast<int>(device_type));
if (op_device == device_type && !op.mem_id().empty()) {
const DataType op_dtype = static_cast<DataType>(
ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
op, "T", static_cast<int>(DT_FLOAT)));
if (op_dtype != DataType::DT_INVALID) {
dtype = op_dtype;
// find first valid data type, break
break;
}
}
}
MACE_CHECK(dtype != DataType::DT_INVALID, "data type is invalid.");
}
MACE_CHECK(dtype != DataType::DT_INVALID, "data type is invalid.");
// TODO(liyin): memory blocks should not have a concept of data type, but to
// be consistent with GPU, all memory blocks use float/half as the unit
for (auto &mem_block : net_def.mem_arena().mem_block()) {
......@@ -239,36 +252,58 @@ MaceStatus Workspace::CreateOutputTensorBuffer(const NetDef &net_def,
const int op_device =
ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
op, "device", static_cast<int>(device_type));
if (op_device == device_type && !op.mem_id().empty()
&& ShouldPreallocateMemoryForOp(op)) {
auto mem_ids = op.mem_id();
int count = mem_ids.size();
for (int i = 0; i < count; ++i) {
DataType output_type;
if (i < op.output_type_size()) {
output_type = op.output_type(i);
} else {
output_type = dtype;
if (op_device == device_type) {
if (!op.mem_id().empty()
&& ShouldPreallocateMemoryForOp(op)) {
auto mem_ids = op.mem_id();
int count = mem_ids.size();
for (int i = 0; i < count; ++i) {
DataType output_type;
if (i < op.output_type_size()) {
output_type = op.output_type(i);
} else {
output_type = dtype;
}
std::unique_ptr<Tensor> tensor
(new Tensor(preallocated_allocator_.GetBuffer(mem_ids[i]),
output_type));
tensor->SetSourceOpName(op.name());
if (device_type == DeviceType::GPU) {
VLOG(3) << "Tensor: " << op.name() << "(" << op.type() << ")"
<< " Mem: " << mem_ids[i]
<< " Image shape: "
<< dynamic_cast<Image *>(tensor->UnderlyingBuffer())
->image_shape()[0]
<< ", "
<< dynamic_cast<Image *>(tensor->UnderlyingBuffer())
->image_shape()[1];
} else if (device_type == DeviceType::CPU) {
VLOG(3) << "Tensor: " << op.name() << "(" << op.type() << ")"
<< " Mem: " << mem_ids[i]
<< ", Buffer size: " << tensor->UnderlyingBuffer()->size();
}
tensor_map_[op.output(i)] = std::move(tensor);
}
std::unique_ptr<Tensor> tensor
(new Tensor(preallocated_allocator_.GetBuffer(mem_ids[i]),
output_type));
tensor->SetSourceOpName(op.name());
if (device_type == DeviceType::GPU) {
VLOG(3) << "Tensor: " << op.name() << "(" << op.type() << ")"
<< " Mem: " << mem_ids[i]
<< " Image shape: "
<< dynamic_cast<Image *>(tensor->UnderlyingBuffer())
->image_shape()[0]
<< ", "
<< dynamic_cast<Image *>(tensor->UnderlyingBuffer())
->image_shape()[1];
} else if (device_type == DeviceType::CPU) {
VLOG(3) << "Tensor: " << op.name() << "(" << op.type() << ")"
<< " Mem: " << mem_ids[i]
<< ", Buffer size: " << tensor->UnderlyingBuffer()->size();
} else {
for (int i = 0; i < op.output().size(); ++i) {
MACE_CHECK(
op.output_type_size() == 0
|| op.output_size()
== op.output_type_size(),
"operator output size != operator output type size",
op.output_size(),
op.output_type_size());
DataType output_type;
if (i < op.output_type_size()) {
output_type = op.output_type(i);
} else {
output_type = static_cast<DataType>(ProtoArgHelper::GetOptionalArg(
op, "T", static_cast<int>(DT_FLOAT)));
}
CreateTensor(op.output(i),
GetDeviceAllocator(device_type),
output_type);
}
tensor_map_[op.output(i)] = std::move(tensor);
}
}
}
......
......@@ -488,6 +488,10 @@ struct Conv2dFunctor<DeviceType::CPU, float> : Conv2dFunctorBase {
const index_t extra_output_shape[4] =
{batch, channels, extra_output_height, extra_output_width};
// make host compiler happy
MACE_UNUSED(extra_input_shape);
MACE_UNUSED(extra_output_shape);
// decide which convolution function to call
if (use_winograd) {
transformed_input.Reshape(transformed_input_shape);
......
......@@ -217,6 +217,10 @@ struct DepthwiseConv2dFunctor<DeviceType::CPU, float>
const index_t input_shape[4] =
{batch, input_channels, input_height, input_width};
// make host compiler happy
MACE_UNUSED(pad_hw);
MACE_UNUSED(input_shape);
if (filter_h == 3 && filter_w == 3 && stride_h == 1 && stride_w == 1
&& dilation_h == 1 && dilation_w == 1) {
conv_func = [=](const float *input, float *output) {
......
......@@ -58,10 +58,11 @@ inline void AdjustRange(const float in_min_data,
if (fabs(quantized_zero - quantized_zero_near_int) > kEps) {
if (quantized_zero < quantized_zero_near_int || non_zero) {
// keep out_max fixed, and move out_min
*scale = out_max / (quantized_max - quantized_zero_near_int);
*zero_point = static_cast<int32_t>(std::ceil(quantized_zero));
*scale = out_max / (quantized_max - *zero_point);
} else {
// keep out_min fixed, and move out_max
*scale = -out_min / quantized_zero_near_int;
*scale = out_min / (quantized_min - *zero_point);
}
}
} else if (out_min > -kEps) {
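A concrete example of the corrected nudge: for out_min = -0.3 and out_max = 1.0, the initial scale is 1.3 / 255 and the ideal zero point is 0.3 * 255 / 1.3 ≈ 58.85. The code now rounds the zero point up to 59 and recomputes scale = out_max / (quantized_max - zero_point) = 1.0 / 196 ≈ 0.005102, so real zero maps exactly onto an integer step while out_max stays fixed and out_min only widens slightly, to -59 / 196 ≈ -0.301.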
......@@ -96,6 +97,18 @@ inline void FindMinMax(const float *input,
*max_val = max_v;
}
template<typename T>
inline void QuantizeWithScaleAndZeropoint(const float *input,
const index_t size,
float scale,
int32_t zero_point,
T *output) {
float recip_scale = 1 / scale;
for (index_t i = 0; i < size; ++i) {
output[i] = Saturate<T>(roundf(zero_point + recip_scale * input[i]));
}
}
template<typename T>
inline void Quantize(const float *input,
const index_t size,
......@@ -110,10 +123,7 @@ inline void Quantize(const float *input,
AdjustRange<T>(in_min_data, in_max_data, non_zero,
scale, zero_point);
float recip_scale = 1 / *scale;
for (int i = 0; i < size; ++i) {
output[i] = Saturate<T>(roundf(*zero_point + recip_scale * input[i]));
}
QuantizeWithScaleAndZeropoint(input, size, *scale, *zero_point, output);
}
template<typename T>
......@@ -143,16 +153,24 @@ struct QuantizeFunctor<CPU, uint8_t> {
Tensor::MappingGuard output_guard(output);
const float *input_data = input->data<float>();
uint8_t *output_data = output->mutable_data<uint8_t>();
float scale;
int32_t zero_point;
Quantize(input_data,
input->size(),
non_zero,
output_data,
&scale,
&zero_point);
output->SetScale(scale);
output->SetZeroPoint(zero_point);
if (output->scale() > 0.f) {
QuantizeWithScaleAndZeropoint(input_data,
input->size(),
output->scale(),
output->zero_point(),
output_data);
} else {
float scale;
int32_t zero_point;
Quantize(input_data,
input->size(),
non_zero,
output_data,
&scale,
&zero_point);
output->SetScale(scale);
output->SetZeroPoint(zero_point);
}
return MACE_SUCCESS;
}
......
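Taken together, the scheme is standard affine (asymmetric) quantization: q = saturate(round(zero_point + x / scale)) on the way in, and x' = scale * (q - zero_point) on the way out. A minimal NumPy sketch of the round trip, assuming uint8 as in QuantizeFunctor above:

import numpy as np

def quantize_u8(x, scale, zero_point):
    # q = saturate(round(zero_point + x / scale)), as in the C++ above
    return np.clip(np.round(zero_point + x / scale), 0, 255).astype(np.uint8)

def dequantize_u8(q, scale, zero_point):
    # x' = scale * (q - zero_point)
    return scale * (q.astype(np.float32) - zero_point)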
......@@ -34,6 +34,8 @@ message ConstTensor {
optional string name = 5;
optional int64 offset = 6;
optional int64 data_size = 7;
optional float scale = 8;
optional int32 zero_point = 9;
optional uint32 node_id = 100;
}
......@@ -104,12 +106,23 @@ message OutputInfo {
optional DataType data_type = 5 [default = DT_FLOAT];
}
message QuantizeActivationInfo {
optional string tensor_name = 1;
optional float scale = 2;
optional int32 zero_point = 3;
}
message QuantizeInfo {
repeated QuantizeActivationInfo activation_info = 1;
}
message NetDef {
optional string name = 1;
repeated OperatorDef op = 2;
optional string version = 3;
repeated Argument arg = 4;
repeated ConstTensor tensors = 5;
optional QuantizeInfo quantize_info = 6;
// for mem optimization
optional MemoryArena mem_arena = 10;
......
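The new fields can be filled in from Python through the generated mace_pb2 module. A minimal sketch (the tensor name and values are hypothetical):

from mace.proto import mace_pb2

net_def = mace_pb2.NetDef()
info = net_def.quantize_info.activation_info.add()
info.tensor_name = "conv1_output"  # hypothetical name
info.scale = 0.005102
info.zero_point = 59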
py_library(
name = "quantization_lib",
srcs = [
"quantization/quantize_util.py",
],
srcs_version = "PY2AND3",
)
py_library(
name = "converter_lib",
srcs = [
......@@ -12,6 +20,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
":quantization_lib",
":memory_optimizer",
"//mace/proto:mace_py",
"//third_party/caffe:caffe_py",
......
......@@ -96,12 +96,12 @@ def main(unused_args):
print ("runtime %s is not supported." % FLAGS.runtime)
sys.exit(-1)
option = cvt.ConverterOption()
if FLAGS.graph_optimize_options:
option = cvt.ConverterOption(
FLAGS.graph_optimize_options.split(','))
else:
option = cvt.ConverterOption()
option.transformer_option = FLAGS.graph_optimize_options.split(',')
option.winograd = FLAGS.winograd
option.quantize = FLAGS.quantize
option.quantize_range_file = FLAGS.quantize_range_file
input_node_names = FLAGS.input_node.split(',')
input_node_shapes = FLAGS.input_shape.split(':')
......@@ -119,6 +119,8 @@ def main(unused_args):
output_node.name = output_node_names[i]
option.add_output_node(output_node)
option.build()
print("Transform model to one that can better run on device")
if FLAGS.runtime == 'dsp':
mace_check(FLAGS.platform == 'tensorflow',
......@@ -297,6 +299,18 @@ def parse_args():
type=str,
default="",
help="graph optimize options")
parser.add_argument(
"--quantize",
type=str2bool,
nargs='?',
const=False,
default=False,
help="quantize model")
parser.add_argument(
"--quantize_range_file",
type=str,
default="",
help="file path of quantize range for each tensor")
return parser.parse_known_args()
......
......@@ -97,7 +97,6 @@ MaceSupportedOps = [
'Proposal',
'Quantize',
'ReduceMean',
'Requantize',
'Reshape',
'ResizeBilinear',
'Slice',
......@@ -189,6 +188,9 @@ class TransformerRule(Enum):
ADD_IN_OUT_TENSOR_INFO = 20
ADD_MACE_INPUT_AND_OUTPUT_NODES = 21
UPDATE_FLOAT_OP_DATA_TYPE = 22
QUANTIZE_NODES = 23
ADD_QUANTIZE_TENSOR_RANGE = 24
QUANTIZE_WEIGHTS = 25
class ConverterInterface(object):
......@@ -228,40 +230,15 @@ class NodeInfo(object):
class ConverterOption(object):
"""A class for specifying options passed to converter tool"""
def __init__(self, transformers=None):
def __init__(self):
self._input_nodes = {}
self._output_nodes = {}
self._data_type = mace_pb2.DT_FLOAT
self._device = DeviceType.CPU.value
self._winograd = 0
if transformers:
self._transformer_option = [TransformerRule[transformer]
for transformer in transformers]
else:
self._transformer_option = [
TransformerRule.REMOVE_IDENTITY_OP,
TransformerRule.TRANSFORM_GLOBAL_POOLING,
TransformerRule.FOLD_RESHAPE,
TransformerRule.TRANSFORM_MATMUL_TO_FC,
TransformerRule.FOLD_BATCHNORM,
TransformerRule.FOLD_CONV_AND_BN,
TransformerRule.FOLD_DEPTHWISE_CONV_AND_BN,
TransformerRule.TRANSFORM_GPU_WINOGRAD,
TransformerRule.TRANSFORM_ADD_TO_BIASADD,
TransformerRule.FOLD_BIASADD,
TransformerRule.FLATTEN_ATROUS_CONV,
TransformerRule.FOLD_ACTIVATION,
TransformerRule.TRANSPOSE_FILTERS,
TransformerRule.TRANSPOSE_DATA_FORMAT,
TransformerRule.ADD_IN_OUT_TENSOR_INFO,
TransformerRule.TRANSFORM_GLOBAL_CONV_TO_FC,
TransformerRule.RESHAPE_FC_WEIGHT,
TransformerRule.TRANSFORM_BUFFER_IMAGE,
TransformerRule.ADD_DEVICE,
TransformerRule.UPDATE_FLOAT_OP_DATA_TYPE,
TransformerRule.ADD_MACE_INPUT_AND_OUTPUT_NODES,
TransformerRule.SORT_BY_EXECUTION,
]
self._quantize = False
self._quantize_range_file = ""
self._transformer_option = None
@property
def input_nodes(self):
......@@ -283,6 +260,14 @@ class ConverterOption(object):
def winograd(self):
return self._winograd
@property
def quantize(self):
return self._quantize
@property
def quantize_range_file(self):
return self._quantize_range_file
@property
def transformer_option(self):
return self._transformer_option
......@@ -315,6 +300,18 @@ class ConverterOption(object):
def winograd(self, winograd):
self._winograd = winograd
@quantize.setter
def quantize(self, quantize):
self._quantize = quantize
@quantize_range_file.setter
def quantize_range_file(self, quantize_range_file):
self._quantize_range_file = quantize_range_file
@transformer_option.setter
def transformer_option(self, transformer_option):
self._transformer_option = transformer_option
def disable_transpose_filters(self):
if TransformerRule.TRANSPOSE_FILTERS in self._transformer_option:
self._transformer_option.remove(TransformerRule.TRANSPOSE_FILTERS)
......@@ -323,6 +320,58 @@ class ConverterOption(object):
if TransformerRule.TRANSPOSE_FILTERS not in self._transformer_option:
self._transformer_option.append(TransformerRule.TRANSPOSE_FILTERS)
def build(self):
if self._transformer_option:
self._transformer_option = [TransformerRule[transformer]
for transformer in self._transformer_option] # noqa
else:
if not self._quantize:
self._transformer_option = [
TransformerRule.REMOVE_IDENTITY_OP,
TransformerRule.TRANSFORM_GLOBAL_POOLING,
TransformerRule.FOLD_RESHAPE,
TransformerRule.TRANSFORM_MATMUL_TO_FC,
TransformerRule.FOLD_BATCHNORM,
TransformerRule.FOLD_CONV_AND_BN,
TransformerRule.FOLD_DEPTHWISE_CONV_AND_BN,
TransformerRule.TRANSFORM_GPU_WINOGRAD,
TransformerRule.TRANSFORM_ADD_TO_BIASADD,
TransformerRule.FOLD_BIASADD,
TransformerRule.FLATTEN_ATROUS_CONV,
TransformerRule.FOLD_ACTIVATION,
TransformerRule.TRANSPOSE_FILTERS,
TransformerRule.TRANSPOSE_DATA_FORMAT,
TransformerRule.ADD_IN_OUT_TENSOR_INFO,
TransformerRule.TRANSFORM_GLOBAL_CONV_TO_FC,
TransformerRule.RESHAPE_FC_WEIGHT,
TransformerRule.TRANSFORM_BUFFER_IMAGE,
TransformerRule.ADD_DEVICE,
TransformerRule.UPDATE_FLOAT_OP_DATA_TYPE,
TransformerRule.ADD_MACE_INPUT_AND_OUTPUT_NODES,
TransformerRule.SORT_BY_EXECUTION,
]
else:
self._transformer_option = [
TransformerRule.REMOVE_IDENTITY_OP,
TransformerRule.TRANSFORM_GLOBAL_POOLING,
TransformerRule.FOLD_RESHAPE,
TransformerRule.TRANSFORM_MATMUL_TO_FC,
TransformerRule.FOLD_BATCHNORM,
TransformerRule.FOLD_CONV_AND_BN,
TransformerRule.FOLD_DEPTHWISE_CONV_AND_BN,
TransformerRule.TRANSFORM_GPU_WINOGRAD,
TransformerRule.TRANSFORM_ADD_TO_BIASADD,
TransformerRule.FOLD_BIASADD,
TransformerRule.FLATTEN_ATROUS_CONV,
TransformerRule.FOLD_ACTIVATION,
TransformerRule.ADD_IN_OUT_TENSOR_INFO,
TransformerRule.QUANTIZE_NODES,
TransformerRule.ADD_QUANTIZE_TENSOR_RANGE,
TransformerRule.QUANTIZE_WEIGHTS,
TransformerRule.ADD_DEVICE,
TransformerRule.SORT_BY_EXECUTION,
]
class ConverterUtil(object):
@staticmethod
......@@ -338,6 +387,12 @@ class ConverterUtil(object):
data_format_arg.name = MaceKeyword.mace_data_format_str
data_format_arg.i = data_format.value
@staticmethod
def add_data_type_arg(op, data_type):
data_type_arg = op.arg.add()
data_type_arg.name = MaceKeyword.mace_op_data_type_str
data_type_arg.i = data_type
@staticmethod
def data_format(op):
arg = ConverterUtil.get_arg(op, MaceKeyword.mace_data_format_str)
......
......@@ -31,7 +31,7 @@ from mace.python.tools.converter_tool.base_converter import TransformerRule
from mace.python.tools.convert_util import calculate_image_shape
from mace.python.tools.convert_util import mace_check
from mace.python.tools.convert_util import OpenCLBufferType
from mace.python.tools.quantization import quantize_util
OPENCL_IMAGE_MAX_SIZE = 16384
......@@ -73,6 +73,12 @@ class Transformer(base_converter.ConverterInterface):
TransformerRule.RESHAPE_FC_WEIGHT: self.reshape_fc_weight,
TransformerRule.TRANSFORM_BUFFER_IMAGE:
self.transform_buffer_image,
TransformerRule.QUANTIZE_NODES:
self.quantize_nodes,
TransformerRule.ADD_QUANTIZE_TENSOR_RANGE:
self.add_quantize_tensor_range,
TransformerRule.QUANTIZE_WEIGHTS:
self.quantize_weights,
TransformerRule.ADD_DEVICE:
self.add_device,
TransformerRule.UPDATE_FLOAT_OP_DATA_TYPE:
......@@ -93,6 +99,8 @@ class Transformer(base_converter.ConverterInterface):
self._target_data_format = DataFormat.NHWC
self._input_output_added = False
self._opencl_max_image_size = [0, 0]
self._quantize_activation_info = {}
self._quantized_tensor = set()
if self._option.device == DeviceType.CPU.value:
self._target_data_format = DataFormat.NCHW
......@@ -854,6 +862,7 @@ class Transformer(base_converter.ConverterInterface):
else:
op.type = MaceOp.Identity.name
ConverterUtil.add_data_type_arg(op, mace_pb2.DT_FLOAT)
ConverterUtil.add_data_format_arg(op, DataFormat.NCHW)
for output_node in self._option.output_nodes.values():
......@@ -877,6 +886,7 @@ class Transformer(base_converter.ConverterInterface):
ConverterUtil.add_data_format_arg(op, DataFormat.NHWC)
else:
op.type = MaceOp.Identity.name
ConverterUtil.add_data_type_arg(op, mace_pb2.DT_FLOAT)
self._input_output_added = True
......@@ -963,6 +973,7 @@ class Transformer(base_converter.ConverterInterface):
arg = op_def.arg.add()
arg.name = MaceKeyword.mace_mode
arg.i = 0
ConverterUtil.add_data_type_arg(op_def, mace_pb2.DT_FLOAT)
tensor_shape = list(self._consts[input_name].dims)
if input_type == OpenCLBufferType.WINOGRAD_FILTER:
......@@ -1054,6 +1065,7 @@ class Transformer(base_converter.ConverterInterface):
arg.name = MaceKeyword.mace_buffer_type
arg.i = OpenCLBufferType.IN_OUT_CHANNEL.value
ConverterUtil.add_data_type_arg(op_def, mace_pb2.DT_FLOAT)
ConverterUtil.add_data_format_arg(op_def, DataFormat.NHWC)
for output_node in self._option.output_nodes.values():
......@@ -1072,6 +1084,7 @@ class Transformer(base_converter.ConverterInterface):
arg.name = MaceKeyword.mace_buffer_type
arg.i = OpenCLBufferType.IN_OUT_CHANNEL.value
ConverterUtil.add_data_type_arg(op_def, mace_pb2.DT_FLOAT)
ConverterUtil.add_data_format_arg(op_def, DataFormat.NHWC)
self._input_output_added = True
......@@ -1276,6 +1289,7 @@ class Transformer(base_converter.ConverterInterface):
output_shape = op_def.output_shape.add()
output_shape.dims.extend(input_node.shape)
ConverterUtil.add_data_type_arg(op_def, mace_pb2.DT_FLOAT)
ConverterUtil.add_data_format_arg(op_def, DataFormat.NHWC)
for output_node in self._option.output_nodes.values():
......@@ -1290,6 +1304,8 @@ class Transformer(base_converter.ConverterInterface):
output_shape.dims.extend(
self._producer[output_node.name].output_shape[0].dims)
ConverterUtil.add_data_type_arg(op_def, mace_pb2.DT_FLOAT)
def sort_by_execution(self):
print("Sort by execution")
net = self._model
......@@ -1311,3 +1327,103 @@ class Transformer(base_converter.ConverterInterface):
print("%s (%s): %s" % (op.name, op.type, [
out_shape.dims for out_shape in op.output_shape]))
return False
def quantize_nodes(self):
print("Add mace quantize and dequantize nodes")
for op in self._model.op:
data_type_arg = ConverterUtil.get_arg(
op, MaceKeyword.mace_op_data_type_str)
mace_check(data_type_arg, "Data type does not exist for %s(%s)"
% (op.name, op.type))
if data_type_arg.i == mace_pb2.DT_FLOAT:
data_type_arg.i = mace_pb2.DT_UINT8
else:
mace_check(False,
"Quantization only support float ops, "
"but get %s(%s)"
% (op.name, op.type))
for input_node in self._option.input_nodes.values():
new_input_name = MaceKeyword.mace_input_node_name \
+ '_' + input_node.name
op_def = self._model.op.add()
op_def.name = self.normalize_op_name(input_node.name)
op_def.type = MaceOp.Quantize.name
op_def.input.extend([new_input_name])
op_def.output.extend([input_node.name])
output_shape = op_def.output_shape.add()
output_shape.dims.extend(input_node.shape)
ConverterUtil.add_data_type_arg(op_def, mace_pb2.DT_UINT8)
ConverterUtil.add_data_format_arg(op_def, DataFormat.NHWC)
for output_node in self._option.output_nodes.values():
output_name = MaceKeyword.mace_output_node_name \
+ '_' + output_node.name
op_def = self._model.op.add()
op_def.name = self.normalize_op_name(output_name)
op_def.type = MaceOp.Dequantize.name
op_def.input.extend([output_node.name])
op_def.output.extend([output_name])
output_shape = op_def.output_shape.add()
output_shape.dims.extend(
self._producer[output_node.name].output_shape[0].dims)
ConverterUtil.add_data_type_arg(op_def, mace_pb2.DT_UINT8)
self._input_output_added = True
def add_quantize_tensor_range(self):
print("Add quantize tensor range")
net = self._model
range_file = self._option.quantize_range_file
with open(range_file) as f:
for line in f:
tensor_name, minmax = line.split("@@")
min_val, max_val = [float(i) for i in
minmax.strip().split(",")]
scale, zero = quantize_util.adjust_range(min_val, max_val,
non_zero=False)
activation_info = net.quantize_info.activation_info.add()
activation_info.tensor_name = tensor_name
activation_info.scale = scale
activation_info.zero_point = zero
self._quantize_activation_info[tensor_name] = activation_info
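# Note: the range file parsed above is expected to contain one line per
# tensor in the same "name@@min,max" format that quantize_stat.py (added
# below) prints, e.g. (hypothetical tensor name):
#   MobilenetV1/Relu6@@0.000000,5.998828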
def quantize_tensor(self, tensor):
"""Assume biasadd has been already folded with convolution and fc"""
if tensor.data_type == mace_pb2.DT_FLOAT:
ops = self._consumers.get(tensor.name, None)
if ops is not None and len(ops) == 1 and ops[0].type in [MaceOp.Conv2D.name,
MaceOp.Deconv2D.name,
MaceOp.DepthwiseConv2d.name,
MaceOp.FullyConnected.name] \
and len(ops[0].input) >= 3 \
and ops[0].input[2] == tensor.name:
conv_op = ops[0]
scale_input = self._quantize_activation_info[
conv_op.input[0]].scale
if conv_op.input[1] not in self._quantized_tensor:
self.quantize_tensor(self._consts[conv_op.input[1]])
scale_filter = self._consts[conv_op.input[1]].scale
scale = scale_input * scale_filter
quantized_tensor = quantize_util.quantize_with_scale_and_zero(
tensor.float_data, scale, 0)
tensor.data_type = mace_pb2.DT_INT32
else:
quantized_tensor = quantize_util.quantize(tensor.float_data)
tensor.data_type = mace_pb2.DT_UINT8
del tensor.float_data[:]
tensor.int32_data.extend(quantized_tensor.data)
tensor.scale = quantized_tensor.scale
tensor.zero_point = quantized_tensor.zero
self._quantized_tensor.update([tensor.name])
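# In effect: a bias consumed by conv/deconv/depthwise-conv/FC is quantized
# to int32 with scale = input_scale * filter_scale and zero_point = 0
# (matching the gemmlowp convention referenced in quantize_util.py), while
# every other float weight becomes uint8 with its own scale/zero_point.
# A sketch of the bias rule:
#   bias_q = np.round(np.asarray(bias, dtype=float)
#                     / (input_scale * filter_scale)).astype(np.int32)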
def quantize_weights(self):
print("Quantize weights")
net = self._model
for tensor in net.tensors:
self.quantize_tensor(tensor)
......@@ -138,6 +138,19 @@ void CreateMemoryArena(mace::MemoryArena *mem_arena) {
}
{% endif %}
void AddQuantizeInfo(NetDef *net_def) {
MACE_LATENCY_LOGGER(1, "Add quantize info");
(void) net_def;
{% for i in range(net.quantize_info.activation_info|length) %}
mace::QuantizeActivationInfo *activation_info{{i}} =
net_def->mutable_quantize_info()->add_activation_info();
activation_info{{i}}->set_tensor_name("{{net.quantize_info.activation_info[i].tensor_name}}");
activation_info{{i}}->set_scale({{net.quantize_info.activation_info[i].scale}});
activation_info{{i}}->set_zero_point({{net.quantize_info.activation_info[i].zero_point}});
{% endfor %}
}
} // namespace
......@@ -166,6 +179,8 @@ const std::shared_ptr<NetDef> CreateNet() {
CreateOutputInfo(net_def.get());
{% endif %}
AddQuantizeInfo(net_def.get());
return net_def;
}
......
......@@ -98,30 +98,6 @@ def obfuscate_name(net_def):
op.output[i] = in_out_map[op.output[i]]
def normalize_op_name(op_name):
idx = op_name.rfind(':')
if idx == -1:
return op_name
else:
return op_name[:idx]
def rename_tensor(net_def):
tensor_map = {}
for t in net_def.tensors:
if t.name not in tensor_map:
tensor_map[t.name] = "_" + normalize_op_name(t.name).replace("/",
"_")
t.name = tensor_map[t.name]
for op in net_def.op:
for i in range(len(op.input)):
if op.input[i] in tensor_map:
op.input[i] = tensor_map[op.input[i]]
for i in range(len(op.output)):
if op.output[i] in tensor_map:
op.output[i] = tensor_map[op.output[i]]
def stringfy(value):
return ', '.join('"{0}"'.format(w) for w in value)
......@@ -301,8 +277,6 @@ def save_model(net_def, model_checksum, weight_checksum, template_dir,
winograd_conv, data_type, model_graph_format):
if obfuscate:
obfuscate_name(net_def)
else:
rename_tensor(net_def)
output_dir = output_dir + '/'
# update tensor type
......
import argparse
import numpy as np
class QuantizeStat(object):
def __init__(self):
pass
@staticmethod
def run(log_file, percentile):
res = {}
tensor_ranges = {}
with open(log_file) as log:
for line in log:
if line.find("Tensor range @@") != -1:
tensor_name, minmax = line.split("@@")[1:]
min_val, max_val = [float(i) for i in
minmax.strip().split(",")]
if tensor_name not in tensor_ranges:
tensor_ranges[tensor_name] = ([], [])
tensor_ranges[tensor_name][0].append(min_val)
tensor_ranges[tensor_name][1].append(max_val)
for tensor_name in tensor_ranges:
tensor_min = np.percentile(tensor_ranges[tensor_name][0],
percentile)
tensor_max = np.percentile(tensor_ranges[tensor_name][1],
100 - percentile)
assert tensor_min < tensor_max
res[tensor_name] = (tensor_min, tensor_max)
return res
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
"--log_file",
type=str,
default="",
help="path of log file that records tensor range")
parser.add_argument(
"--percentile",
type=int,
default=5,
help="range percentile")
FLAGS, unparsed = parser.parse_known_args()
res = QuantizeStat.run(FLAGS.log_file, FLAGS.percentile)
for tensor in res:
print("%s@@%f,%f" % (tensor, res[tensor][0], res[tensor][1]))
import numpy as np
import math
class QuantizedData(object):
def __init__(self):
self._data = None
self._scale = 0
self._zero = 0
@property
def data(self):
return self._data
@property
def scale(self):
return self._scale
@property
def zero(self):
return self._zero
@data.setter
def data(self, data):
self._data = data
@scale.setter
def scale(self, scale):
self._scale = scale
@zero.setter
def zero(self, zero):
self._zero = zero
def adjust_range(in_min, in_max, non_zero):
out_max = max(0.0, in_max)
out_min = min(0.0, in_min)
if non_zero:
out_min = min(out_min, in_min - (out_max - in_min) / 254.0)
scale = (out_max - out_min) / 255.0
eps = 1e-6
if out_min < -eps and out_max > eps:
zero = -out_min / scale
zero_int = int(round(zero))
if abs(zero - zero_int) > eps:
if zero < zero_int or non_zero:
zero_int = int(math.ceil(zero))
scale = out_max / (255.0 - zero_int)
else:
scale = -out_min / zero_int
elif out_min > -eps:
zero_int = 0
else:
zero_int = 255
return scale, zero_int
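# Example (values follow from the logic above): adjust_range(-1.0, 2.0,
# non_zero=False) yields scale = 3.0 / 255 with zero_int = 85 and needs no
# nudge, while adjust_range(-0.3, 1.0, non_zero=False) nudges the ideal
# zero point 58.85 up to 59 and recomputes scale = 1.0 / 196.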
def cal_multiplier_and_shift(scale):
"""
In order to use gemmlowp, we need to use gemmlowp-like transform
:param scale:
:return: multiplier, shift
"""
assert scale > 0, "scale should be > 0, but got %s" % scale
assert scale < 1, "scale should be < 1, but got %s" % scale
multiplier = scale
s = 0
# make range [1/2, 1)
while multiplier < 0.5:
multiplier *= 2.0
s += 1
# convert scale to fixed-point
q = int(round(multiplier * (1 << 31)))
assert q <= (1 << 31)
if q == (1 << 31):
q //= 2  # keep q integral under Python 3
s -= 1
assert s >= 0
return q, s
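# Example: cal_multiplier_and_shift(0.1) doubles 0.1 three times to 0.8
# (so s = 3) and encodes q = round(0.8 * 2**31) = 1717986918; x * 0.1 can
# then be computed in integer arithmetic as (x * q) >> (31 + s).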
def quantize_with_scale_and_zero(data, scale, zero):
output = np.round(zero + np.asarray(data, dtype=float) / scale).astype(int)
quantized_data = QuantizedData()
quantized_data.data = output
quantized_data.scale = scale
quantized_data.zero = zero
return quantized_data
def quantize(data):
np_data = np.array(data).astype(float)
in_min = np_data.min()
in_max = np_data.max()
scale, zero = adjust_range(in_min, in_max, non_zero=True)
output = np.clip(np.round(zero + np_data / scale).astype(int), 0, 255)
quantized_data = QuantizedData()
quantized_data.data = output
quantized_data.scale = scale
quantized_data.zero = zero
return quantized_data
def dequantize(quantized_data):
return quantized_data.scale * (quantized_data.data - quantized_data.zero)
import unittest
import numpy as np
import quantize_util
class TestQuantize(unittest.TestCase):
def test_quantize_dequantize(self):
test_input = np.random.rand(20, 30) * 5
quantized_data = quantize_util.quantize(test_input)
dequantized_output = quantize_util.dequantize(quantized_data)
np.testing.assert_array_almost_equal(test_input, dequantized_output, 2)
if __name__ == '__main__':
unittest.main()
......@@ -32,6 +32,8 @@ void CreateTensor{{tensor_info.id}}(mace::ConstTensor *const_tensor) {
{% endfor %}
const_tensor->set_data_type(static_cast<DataType>({{ tensor_info.data_type }}));
const_tensor->set_node_id({{ tensor.node_id }});
const_tensor->set_scale({{ tensor.scale }});
const_tensor->set_zero_point({{ tensor.zero_point }});
}
} // namespace {{tag}}
......
# Quantize stat build
cc_binary(
name = "quantize_stat",
srcs = ["quantize_stat.cc"],
copts = [
"-Werror",
"-Wextra",
],
linkopts = ["-fopenmp"],
linkstatic = 1,
deps = [
"//external:gflags_nothreads",
"//mace/codegen:generated_mace_engine_factory",
"//mace/codegen:generated_models",
"//mace/libmace",
],
)
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/**
* Usage:
* quantize_stat --model=mobi_mace.pb \
* --input=input_node \
* --output=output_node \
* --input_shape=1,224,224,3 \
* --output_shape=1,224,224,2 \
* --input_dir=input_data_dir \
* --output_file=mace.out \
* --model_data_file=model_data.data
*/
#include <malloc.h>
#include <dirent.h>
#include <stdint.h>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <functional>
#include <iostream>
#include <map>
#include <numeric>
#include <string>
#include <vector>
#include "gflags/gflags.h"
#include "mace/public/mace.h"
#include "mace/public/mace_runtime.h"
#include "mace/utils/env_time.h"
#include "mace/utils/logging.h"
#include "mace/utils/utils.h"
#ifdef MODEL_GRAPH_FORMAT_CODE
#include "mace/codegen/engine/mace_engine_factory.h"
#endif
namespace mace {
namespace tools {
namespace quantization {
namespace str_util {
std::vector<std::string> Split(const std::string &str, char delims) {
std::vector<std::string> result;
if (str.empty()) {
result.push_back("");
return result;
}
std::string tmp = str;
while (!tmp.empty()) {
size_t next_offset = tmp.find(delims);
result.push_back(tmp.substr(0, next_offset));
if (next_offset == std::string::npos) {
break;
} else {
tmp = tmp.substr(next_offset + 1);
}
}
return result;
}
} // namespace str_util
void ParseShape(const std::string &str, std::vector<int64_t> *shape) {
std::string tmp = str;
while (!tmp.empty()) {
int dim = atoi(tmp.data());
shape->push_back(dim);
size_t next_offset = tmp.find(",");
if (next_offset == std::string::npos) {
break;
} else {
tmp = tmp.substr(next_offset + 1);
}
}
}
std::string FormatName(const std::string &input) {
std::string res = input;
for (size_t i = 0; i < input.size(); ++i) {
if (!isalnum(res[i])) res[i] = '_';
}
return res;
}
DEFINE_string(model_name,
"",
"model name in yaml");
DEFINE_string(input_node,
"input_node0,input_node1",
"input nodes, separated by comma");
DEFINE_string(input_shape,
"1,224,224,3:1,1,1,10",
"input shapes, separated by colon and comma");
DEFINE_string(output_node,
"output_node0,output_node1",
"output nodes, separated by comma");
DEFINE_string(output_shape,
"1,224,224,2:1,1,1,10",
"output shapes, separated by colon and comma");
DEFINE_string(input_dir,
"",
"input directory name");
DEFINE_string(model_data_file,
"",
"model data file name, used when EMBED_MODEL_DATA set to 0 or 2");
DEFINE_string(model_file,
"",
"model file name, used when load mace model in pb");
DEFINE_int32(omp_num_threads, -1, "num of openmp threads");
bool RunModel(const std::string &model_name,
const std::vector<std::string> &input_names,
const std::vector<std::vector<int64_t>> &input_shapes,
const std::vector<std::string> &output_names,
const std::vector<std::vector<int64_t>> &output_shapes) {
MACE_RETURN_IF_ERROR(mace::SetOpenMPThreadPolicy(
FLAGS_omp_num_threads, CPUAffinityPolicy::AFFINITY_NONE));
std::vector<unsigned char> model_pb_data;
if (FLAGS_model_file != "") {
if (!mace::ReadBinaryFile(&model_pb_data, FLAGS_model_file)) {
LOG(FATAL) << "Failed to read file: " << FLAGS_model_file;
}
}
std::shared_ptr<mace::MaceEngine> engine;
// Create Engine
#ifdef MODEL_GRAPH_FORMAT_CODE
MACE_RETURN_IF_ERROR(
CreateMaceEngineFromCode(model_name,
FLAGS_model_data_file,
input_names,
output_names,
DeviceType::CPU,
&engine));
#else
(void) (model_name);
MACE_RETURN_IF_ERROR(
CreateMaceEngineFromProto(model_pb_data,
FLAGS_model_data_file,
input_names,
output_names,
DeviceType::CPU,
&engine));
#endif
const size_t input_count = input_names.size();
const size_t output_count = output_names.size();
std::map<std::string, mace::MaceTensor> inputs;
std::map<std::string, mace::MaceTensor> outputs;
std::map<std::string, int64_t> inputs_size;
for (size_t i = 0; i < input_count; ++i) {
int64_t input_size =
std::accumulate(input_shapes[i].begin(), input_shapes[i].end(), 1,
std::multiplies<int64_t>());
inputs_size[input_names[i]] = input_size;
auto buffer_in = std::shared_ptr<float>(new float[input_size],
std::default_delete<float[]>());
inputs[input_names[i]] = mace::MaceTensor(input_shapes[i], buffer_in);
}
for (size_t i = 0; i < output_count; ++i) {
int64_t output_size =
std::accumulate(output_shapes[i].begin(), output_shapes[i].end(), 1,
std::multiplies<int64_t>());
auto buffer_out = std::shared_ptr<float>(new float[output_size],
std::default_delete<float[]>());
outputs[output_names[i]] = mace::MaceTensor(output_shapes[i], buffer_out);
}
DIR *dir_parent;
struct dirent *entry;
dir_parent = opendir(FLAGS_input_dir.c_str());
if (dir_parent) {
while ((entry = readdir(dir_parent))) {
std::string file_name = std::string(entry->d_name);
std::string prefix = FormatName(input_names[0]);
if (file_name.find(prefix) == 0) {
std::string suffix = file_name.substr(prefix.size());
for (size_t i = 0; i < input_count; ++i) {
file_name = FLAGS_input_dir + "/" + FormatName(input_names[i])
+ suffix;
std::ifstream in_file(file_name, std::ios::in | std::ios::binary);
VLOG(2) << "Read " << file_name;
if (in_file.is_open()) {
in_file.read(reinterpret_cast<char *>(
inputs[input_names[i]].data().get()),
inputs_size[input_names[i]] * sizeof(float));
in_file.close();
} else {
LOG(INFO) << "Open input file failed";
return -1;
}
}
MACE_RETURN_IF_ERROR(engine->Run(inputs, &outputs));
}
}
closedir(dir_parent);
} else {
LOG(ERROR) << "Directory " << FLAGS_input_dir << " does not exist.";
}
return true;
}
int Main(int argc, char **argv) {
std::string usage = "quantize stat model\nusage: " + std::string(argv[0])
+ " [flags]";
gflags::SetUsageMessage(usage);
gflags::ParseCommandLineFlags(&argc, &argv, true);
LOG(INFO) << "model name: " << FLAGS_model_name;
LOG(INFO) << "mace version: " << MaceVersion();
LOG(INFO) << "input node: " << FLAGS_input_node;
LOG(INFO) << "input shape: " << FLAGS_input_shape;
LOG(INFO) << "output node: " << FLAGS_output_node;
LOG(INFO) << "output shape: " << FLAGS_output_shape;
LOG(INFO) << "input_dir: " << FLAGS_input_dir;
LOG(INFO) << "model_data_file: " << FLAGS_model_data_file;
LOG(INFO) << "model_file: " << FLAGS_model_file;
LOG(INFO) << "omp_num_threads: " << FLAGS_omp_num_threads;
std::vector<std::string> input_names = str_util::Split(FLAGS_input_node, ',');
std::vector<std::string> output_names =
str_util::Split(FLAGS_output_node, ',');
std::vector<std::string> input_shapes =
str_util::Split(FLAGS_input_shape, ':');
std::vector<std::string> output_shapes =
str_util::Split(FLAGS_output_shape, ':');
const size_t input_count = input_shapes.size();
const size_t output_count = output_shapes.size();
std::vector<std::vector<int64_t>> input_shape_vec(input_count);
std::vector<std::vector<int64_t>> output_shape_vec(output_count);
for (size_t i = 0; i < input_count; ++i) {
ParseShape(input_shapes[i], &input_shape_vec[i]);
}
for (size_t i = 0; i < output_count; ++i) {
ParseShape(output_shapes[i], &output_shape_vec[i]);
}
return RunModel(FLAGS_model_name, input_names, input_shape_vec,
output_names, output_shape_vec);
}
} // namespace quantization
} // namespace tools
} // namespace mace
int main(int argc, char **argv) {
return mace::tools::quantization::Main(argc, argv);
}
......@@ -19,6 +19,7 @@
#include <map>
#include <sstream>
#include <string>
#include <cstdlib>
#include <utility>
#include <vector>
......@@ -162,5 +163,10 @@ std::vector<std::string> MapKeys(const std::map<std::string, T> &data) {
return keys;
}
inline bool EnvEnabled(const std::string &env_name) {
char *env = getenv(env_name.c_str());
return !(!env || env[0] == 0 || env[0] == '0');
}
} // namespace mace
#endif // MACE_UTILS_UTILS_H_
......@@ -70,6 +70,7 @@ MACE_RUN_STATIC_NAME = "mace_run_static"
MACE_RUN_DYNAMIC_NAME = "mace_run_dynamic"
MACE_RUN_STATIC_TARGET = "//mace/tools/validation:" + MACE_RUN_STATIC_NAME
MACE_RUN_DYNAMIC_TARGET = "//mace/tools/validation:" + MACE_RUN_DYNAMIC_NAME
QUANTIZE_STAT_TARGET = "//mace/tools/quantization:quantize_stat"
EXAMPLE_STATIC_NAME = "example_static"
EXAMPLE_DYNAMIC_NAME = "example_dynamic"
EXAMPLE_STATIC_TARGET = "//mace/examples/cli:" + EXAMPLE_STATIC_NAME
......@@ -185,6 +186,8 @@ class YAMLKeyword(object):
nnlib_graph_mode = 'nnlib_graph_mode'
obfuscate = 'obfuscate'
winograd = 'winograd'
quantize = 'quantize'
quantize_range_file = 'quantize_range_file'
validation_inputs_data = 'validation_inputs_data'
graph_optimize_options = 'graph_optimize_options' # internal use for now
......@@ -459,7 +462,8 @@ def format_model_config(flags):
for key in [YAMLKeyword.limit_opencl_kernel_time,
YAMLKeyword.nnlib_graph_mode,
YAMLKeyword.obfuscate,
YAMLKeyword.winograd]:
YAMLKeyword.winograd,
YAMLKeyword.quantize]:
value = model_config.get(key, "")
if value == "":
model_config[key] = 0
......@@ -705,6 +709,8 @@ def convert_model(configs):
model_config[YAMLKeyword.nnlib_graph_mode],
embed_model_data,
model_config[YAMLKeyword.winograd],
model_config[YAMLKeyword.quantize],
model_config.get(YAMLKeyword.quantize_range_file, ""),
model_config[YAMLKeyword.obfuscate],
configs[YAMLKeyword.model_graph_format],
data_type,
......@@ -871,6 +877,37 @@ def build_mace_run(configs, target_abi, enable_openmp, address_sanitizer,
mace_lib_type == MACELibType.dynamic)
def build_quantize_stat(configs):
library_name = configs[YAMLKeyword.library_name]
build_tmp_binary_dir = get_build_binary_dir(library_name, ABIType.host)
if os.path.exists(build_tmp_binary_dir):
sh.rm("-rf", build_tmp_binary_dir)
os.makedirs(build_tmp_binary_dir)
quantize_stat_target = QUANTIZE_STAT_TARGET
build_arg = ""
print (configs[YAMLKeyword.model_graph_format])
if configs[YAMLKeyword.model_graph_format] == ModelFormat.code:
mace_check(os.path.exists(ENGINE_CODEGEN_DIR),
ModuleName.RUN,
"You should convert model first.")
build_arg = "--per_file_copt=mace/tools/quantization/quantize_stat.cc@-DMODEL_GRAPH_FORMAT_CODE" # noqa
sh_commands.bazel_build(
quantize_stat_target,
abi=ABIType.host,
enable_openmp=True,
extra_args=build_arg
)
quantize_stat_filepath = build_tmp_binary_dir + "/quantize_stat"
if os.path.exists(quantize_stat_filepath):
sh.rm("-rf", quantize_stat_filepath)
sh.cp("-f", "bazel-bin/mace/tools/quantization/quantize_stat",
build_tmp_binary_dir)
def build_example(configs, target_abi, enable_openmp, mace_lib_type):
library_name = configs[YAMLKeyword.library_name]
hexagon_mode = get_hexagon_mode(configs)
......@@ -1196,6 +1233,59 @@ def run_specific_target(flags, configs, target_abi,
opencl_parameter_bin_path)
def run_quantize_stat(flags, configs):
library_name = configs[YAMLKeyword.library_name]
build_tmp_binary_dir = get_build_binary_dir(library_name, ABIType.host)
for model_name in configs[YAMLKeyword.models]:
check_model_converted(library_name, model_name,
configs[YAMLKeyword.model_graph_format],
configs[YAMLKeyword.model_data_format],
ABIType.host)
MaceLogger.header(
StringFormatter.block(
"Run model %s on %s" % (model_name, ABIType.host)))
model_config = configs[YAMLKeyword.models][model_name]
subgraphs = model_config[YAMLKeyword.subgraphs]
_, _, mace_model_dir = \
get_build_model_dirs(library_name, model_name, ABIType.host,
None, None,
model_config[YAMLKeyword.model_file_path])
mace_model_path = ""
if configs[YAMLKeyword.model_graph_format] == ModelFormat.file:
mace_model_path = "%s/%s.pb" % (mace_model_dir, model_name)
p = subprocess.Popen(
[
"env",
"MACE_CPP_MIN_VLOG_LEVEL=%s" % flags.vlog_level,
"MACE_LOG_TENSOR_RANGE=1",
"%s/%s" % (build_tmp_binary_dir, "quantize_stat"),
"--model_name=%s" % model_name,
"--input_node=%s" % ",".join(
subgraphs[0][YAMLKeyword.input_tensors]),
"--output_node=%s" % ",".join(
subgraphs[0][YAMLKeyword.output_tensors]),
"--input_shape=%s" % ":".join(
subgraphs[0][YAMLKeyword.input_shapes]),
"--output_shape=%s" % ":".join(
subgraphs[0][YAMLKeyword.output_shapes]),
"--input_dir=%s" % flags.input_dir,
"--model_data_file=%s/%s.data" % (mace_model_dir, model_name),
"--omp_num_threads=%s" % flags.omp_num_threads,
"--model_file=%s" % mace_model_path,
],
stderr=subprocess.PIPE,
stdout=subprocess.PIPE)
out, err = p.communicate()
stdout = err + out
print stdout
print("Running finished!\n")
def print_package_summary(package_path):
title = "Library"
header = ["key", "value"]
......@@ -1216,6 +1306,11 @@ def run_mace(flags):
clear_build_dirs(configs[YAMLKeyword.library_name])
if flags.quantize_stat:
build_quantize_stat(configs)
run_quantize_stat(flags, configs)
return
target_socs = configs[YAMLKeyword.target_socs]
if not target_socs or ALL_SOC_TAG in target_socs:
target_socs = sh_commands.adb_get_all_socs()
......@@ -1582,6 +1677,15 @@ def parse_args():
"--example",
action="store_true",
help="whether to run example.")
run.add_argument(
"--quantize_stat",
action="store_true",
help="whether to stat quantization range.")
run.add_argument(
"--input_dir",
type=str,
default="",
help="quantize stat input dir.")
benchmark = subparsers.add_parser(
'benchmark',
parents=[all_type_parent_parser, run_bm_parent_parser],
......
......@@ -490,6 +490,8 @@ def gen_model_code(model_codegen_dir,
dsp_mode,
embed_model_data,
winograd,
quantize,
quantize_range_file,
obfuscate,
model_graph_format,
data_type,
......@@ -516,6 +518,8 @@ def gen_model_code(model_codegen_dir,
"--dsp_mode=%s" % dsp_mode,
"--embed_model_data=%s" % embed_model_data,
"--winograd=%s" % winograd,
"--quantize=%s" % quantize,
"--quantize_range_file=%s" % quantize_range_file,
"--obfuscate=%s" % obfuscate,
"--output_dir=%s" % model_codegen_dir,
"--model_graph_format=%s" % model_graph_format,
......