Commit 0c6ec590 authored by 李寅

Move quantize info into the op's outputs

Parent 57a3298d
......@@ -109,7 +109,9 @@ class Tensor {
name_(""),
is_weight_(is_weight),
scale_(0.f),
zero_point_(0) {}
zero_point_(0),
minval_(0.f),
maxval_(0.f) {}
Tensor(BufferBase *buffer, DataType dtype,
bool is_weight = false)
......@@ -120,7 +122,9 @@ class Tensor {
name_(""),
is_weight_(is_weight),
scale_(0.f),
zero_point_(0) {}
zero_point_(0),
minval_(0.f),
maxval_(0.f) {}
Tensor(const BufferSlice &buffer_slice,
DataType dtype,
......@@ -132,7 +136,9 @@ class Tensor {
name_(""),
is_weight_(is_weight),
scale_(0.f),
zero_point_(0) {
zero_point_(0),
minval_(0.f),
maxval_(0.f) {
buffer_ = &buffer_slice_;
}
......@@ -391,6 +397,15 @@ class Tensor {
return zero_point_;
}
// Hexagon now uses min/max instead of scale and zero point
inline float minval() const {
return minval_;
}
inline float maxval() const {
return maxval_;
}
inline void SetScale(float scale) {
scale_ = scale;
}
......@@ -403,6 +418,14 @@ class Tensor {
is_weight_ = is_weight;
}
inline void SetMinVal(float minval) {
minval_ = minval;
}
inline void SetMaxVal(float maxval) {
maxval_ = maxval;
}
private:
Allocator *allocator_;
DataType dtype_;
......@@ -416,6 +439,8 @@ class Tensor {
bool is_weight_;
float scale_;
int32_t zero_point_;
float minval_;
float maxval_;
MACE_DISABLE_COPY_AND_ASSIGN(Tensor);
};
......
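The new minval_/maxval_ fields and the existing scale_/zero_point_ describe the same affine uint8 mapping. A minimal Python sketch of the usual derivation, assuming the common nudged-range scheme (the exact policy lives in quantize_util.adjust_range, which is not shown in this diff):

    # Derive (scale, zero_point) from (minval, maxval) for uint8 data.
    # The range is first extended to contain zero so that zero_point
    # lands on an exact uint8 value.
    def to_scale_zero(minval, maxval):
        minval = min(minval, 0.0)
        maxval = max(maxval, 0.0)
        scale = (maxval - minval) / 255.0
        zero_point = int(round(-minval / scale)) if scale > 0.0 else 0
        return scale, zero_point

    print(to_scale_zero(-3.0, 3.0))  # -> (0.023529..., 128)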
......@@ -178,16 +178,19 @@ MaceStatus Workspace::LoadModelTensor(const NetDef &net_def,
if (status != MaceStatus::MACE_SUCCESS) return status;
}
if (device_type == DeviceType::CPU && net_def.has_quantize_info()) {
for (const auto
&activation_info: net_def.quantize_info().activation_info()) {
if (HasTensor(activation_info.tensor_name())) {
Tensor *tensor = GetTensor(activation_info.tensor_name());
tensor->SetScale(activation_info.scale());
tensor->SetZeroPoint(activation_info.zero_point());
} else {
LOG(WARNING) << "Quantize info exists for non-existed tensor: "
<< activation_info.tensor_name();
if (device_type == DeviceType::CPU) {
for (const auto &op : net_def.op()) {
VLOG(2) << "Add quantize info for op: " << op.name();
MACE_CHECK(op.quantize_info().empty()
|| op.quantize_info().size() == op.output().size(),
"quantize info size must be equal to output size or empty");
for (int i = 0; i < op.quantize_info().size(); ++i) {
auto &quantize_info = op.quantize_info(i);
Tensor *tensor = GetTensor(op.output(i));
tensor->SetScale(quantize_info.scale());
tensor->SetZeroPoint(quantize_info.zero_point());
tensor->SetMinVal(quantize_info.minval());
tensor->SetMaxVal(quantize_info.maxval());
}
}
}
......@@ -233,8 +236,7 @@ MaceStatus Workspace::CreateOutputTensorBuffer(const NetDef &net_def,
std::unique_ptr<BufferBase> tensor_buf(
new Buffer(GetCPUAllocator()));
MACE_RETURN_IF_ERROR(tensor_buf->Allocate(
mem_block.x() * GetEnumTypeSize(dtype)
+ MACE_EXTRA_BUFFER_PAD_SIZE));
mem_block.x() + MACE_EXTRA_BUFFER_PAD_SIZE));
preallocated_allocator_.SetBuffer(mem_block.mem_id(),
std::move(tensor_buf));
} else if (mem_block.mem_type() == MemoryType::GPU_IMAGE) {
......
......@@ -879,6 +879,17 @@ struct Conv2dFunctor<DeviceType::CPU, uint8_t> : Conv2dFunctorBase {
const index_t depth = input_channels * filter_h * filter_w;
const index_t columns = batch * height * width;
VLOG(2) << "input scale/zero: " << input->scale() << ", "
<< input->zero_point();
VLOG(2) << "filter scale/zero: " << filter->scale() << ", "
<< filter->zero_point();
if (bias) {
VLOG(2) << "bias scale/zero: " << bias->scale() << ", "
<< bias->zero_point();
}
VLOG(2) << "output scale/zero: " << output->scale() << ", "
<< output->zero_point();
MACE_CHECK(filter->dim(0) == channels, filter->dim(0), " != ", channels);
MACE_CHECK(filter->dim(3) == input_channels, filter->dim(3), " != ",
input_channels);
......
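The scale/zero_point pairs logged above define the affine mapping real_value = scale * (quantized_value - zero_point) that the uint8 kernels compute in. A purely illustrative sketch (not MACE code):

    # Quantize/dequantize helpers matching the logged scale/zero_point.
    def dequantize(q, scale, zero_point):
        return scale * (q - zero_point)

    def quantize(r, scale, zero_point):
        q = int(round(r / scale)) + zero_point
        return max(0, min(255, q))  # clamp to the uint8 range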
......@@ -59,6 +59,13 @@ message OutputShape {
repeated int64 dims = 1;
}
message QuantizeActivationInfo {
optional float scale = 1;
optional int32 zero_point = 2;
optional float minval = 3; // Hexagon uses min/max
optional float maxval = 4;
}
message OperatorDef {
repeated string input = 1;
repeated string output = 2;
......@@ -67,6 +74,7 @@ message OperatorDef {
repeated Argument arg = 5;
repeated OutputShape output_shape = 6;
repeated DataType output_type = 7;
repeated QuantizeActivationInfo quantize_info = 8;
repeated int32 mem_id = 10;
......@@ -106,23 +114,12 @@ message OutputInfo {
optional DataType data_type = 5 [default = DT_FLOAT];
}
message QuantizeActivationInfo {
optional string tensor_name = 1;
optional float scale = 2;
optional int32 zero_point = 3;
}
message QuantizeInfo {
repeated QuantizeActivationInfo activation_info = 1;
}
message NetDef {
optional string name = 1;
repeated OperatorDef op = 2;
optional string version = 3;
repeated Argument arg = 4;
repeated ConstTensor tensors = 5;
optional QuantizeInfo quantize_info = 6;
// for mem optimization
optional MemoryArena mem_arena = 10;
......
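After this change, quantize info is stored per operator and indexed in parallel with its outputs: quantize_info[i] describes output[i], and the list may be empty. A short sketch using the generated protobuf module (import path assumed):

    # Populate the new per-op layout; names and values are illustrative.
    from mace.proto import mace_pb2  # assumed import path

    op = mace_pb2.OperatorDef()
    op.output.append("conv1:0")
    info = op.quantize_info.add()
    info.scale, info.zero_point = 0.0235, 128
    info.minval, info.maxval = -3.0, 3.0
    # The runtime enforces exactly this invariant:
    assert len(op.quantize_info) in (0, len(op.output))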
......@@ -64,6 +64,10 @@ def parse_int_array_from_str(ints_str):
return [int(int_str) for int_str in ints_str.split(',')]
def parse_float_array_from_str(floats_str):
return [float(float_str) for float_str in floats_str.split(',')]
def main(unused_args):
if not os.path.isfile(FLAGS.model_file):
print("Input graph file '" + FLAGS.model_file + "' does not exist!")
......@@ -105,12 +109,18 @@ def main(unused_args):
input_node_names = FLAGS.input_node.split(',')
input_node_shapes = FLAGS.input_shape.split(':')
if FLAGS.input_range:
input_node_ranges = FLAGS.input_range.split(':')
else:
input_node_ranges = []
if len(input_node_names) != len(input_node_shapes):
raise Exception('input node count and shape count do not match.')
for i in xrange(len(input_node_names)):
input_node = cvt.NodeInfo()
input_node.name = input_node_names[i]
input_node.shape = parse_int_array_from_str(input_node_shapes[i])
if len(input_node_ranges) > i:
input_node.range = parse_float_array_from_str(input_node_ranges[i])
option.add_input_node(input_node)
output_node_names = FLAGS.output_node.split(',')
......@@ -276,6 +286,8 @@ def parse_args():
"--dsp_mode", type=int, default=0, help="dsp run mode, defalut=0")
parser.add_argument(
"--input_shape", type=str, default="", help="input shape.")
parser.add_argument(
"--input_range", type=str, default="", help="input range.")
parser.add_argument(
"--platform", type=str, default="tensorflow", help="tensorflow/caffe")
parser.add_argument(
......
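The new --input_range flag mirrors --input_shape: one "min,max" pair per input node, ':'-separated across nodes, parsed with parse_float_array_from_str above. An illustrative parse (values made up):

    # e.g. --input_range=-1,1:-3,3 for two input nodes
    ranges = "-1,1:-3,3".split(':')
    node0_range = [float(v) for v in ranges[0].split(',')]  # [-1.0, 1.0]
    node1_range = [float(v) for v in ranges[1].split(',')]  # [-3.0, 3.0]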
......@@ -203,6 +203,8 @@ class TransformerRule(Enum):
QUANTIZE_WEIGHTS = 25
TRANSFORM_LSTMCELL_ZEROSTATE = 26
TRANSFORM_BASIC_LSTMCELL = 27
TRANSFORM_FAKE_QUANTIZE = 28
CHECK_QUANTIZE_INFO = 29
class ConverterInterface(object):
......@@ -218,6 +220,7 @@ class NodeInfo(object):
def __init__(self):
self._name = None
self._shape = []
self._range = [-1.0, 1.0]
@property
def name(self):
......@@ -227,6 +230,10 @@ class NodeInfo(object):
def shape(self):
return self._shape
@property
def range(self):
return self._range
@name.setter
def name(self, name):
self._name = name
......@@ -235,6 +242,10 @@ class NodeInfo(object):
def shape(self, shape):
self._shape = shape
@range.setter
def range(self, range):
self._range = range
def __str__(self):
return '%s %s' % (self._name, str(self._shape))
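Usage of the new property, assuming base_converter is imported as cvt as in the converter tool above (shape and range values illustrative):

    # An input node with an explicit numeric range; range defaults to
    # [-1.0, 1.0] when --input_range is omitted.
    node = cvt.NodeInfo()
    node.name = 'input'
    node.shape = [1, 224, 224, 3]
    node.range = [-1.0, 1.0]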
......@@ -339,6 +350,7 @@ class ConverterOption(object):
else:
self._transformer_option = [
# Model structure related transformation
TransformerRule.TRANSFORM_FAKE_QUANTIZE,
TransformerRule.REMOVE_IDENTITY_OP,
TransformerRule.TRANSFORM_GLOBAL_POOLING,
TransformerRule.TRANSFORM_LSTMCELL_ZEROSTATE,
......@@ -368,15 +380,17 @@ class ConverterOption(object):
# Transform finalization
TransformerRule.ADD_MACE_INPUT_AND_OUTPUT_NODES,
# for quantization entropy calibration use
TransformerRule.ADD_QUANTIZE_TENSOR_RANGE,
TransformerRule.SORT_BY_EXECUTION,
# Need to be put after SORT_BY_EXECUTION
TransformerRule.ADD_QUANTIZE_TENSOR_RANGE,
]
if self._quantize:
self._transformer_option = self._transformer_option[:-1] + [
self._transformer_option = self._transformer_option + [
# need to be put after ADD_QUANTIZE_TENSOR_RANGE
TransformerRule.QUANTIZE_NODES,
TransformerRule.ADD_QUANTIZE_TENSOR_RANGE,
TransformerRule.QUANTIZE_WEIGHTS,
TransformerRule.SORT_BY_EXECUTION,
TransformerRule.CHECK_QUANTIZE_INFO,
]
......
......@@ -102,6 +102,7 @@ TFSupportedOps = [
'Cast',
'ArgMax',
'Split',
'FakeQuantWithMinMaxVars',
]
TFOpType = Enum('TFOpType', [(op, op) for op in TFSupportedOps], type=str)
......@@ -205,6 +206,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
TFOpType.Cast.name: self.convert_cast,
TFOpType.ArgMax.name: self.convert_argmax,
TFOpType.Split.name: self.convert_split,
TFOpType.FakeQuantWithMinMaxVars.name: self.convert_fake_quantize,
}
self._option = option
self._mace_net_def = mace_pb2.NetDef()
......@@ -841,3 +843,21 @@ class TensorflowConverter(base_converter.ConverterInterface):
num_split_arg.i = tf_op.get_attr('num_split')
self._skip_tensor.add(tf_op.inputs[0].name)
def convert_fake_quantize(self, tf_op):
op = self.convert_general_op(tf_op)
min_arg = op.arg.add()
min_arg.name = 'min'
min_arg.f = tf_op.inputs[1].eval()
max_arg = op.arg.add()
max_arg.name = 'max'
max_arg.f = tf_op.inputs[2].eval()
narrow_range_arg = op.arg.add()
narrow_range_arg.name = 'narrow_range'
narrow_range_arg.i = int(tf_op.get_attr('narrow_range'))
num_bits_arg = op.arg.add()
num_bits_arg.name = 'num_bits'
num_bits_arg.i = int(tf_op.get_attr('num_bits'))
self._skip_tensor.add(tf_op.inputs[1].name)
self._skip_tensor.add(tf_op.inputs[2].name)
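This converter consumes graphs produced by TensorFlow's quantization-aware training. A minimal TF 1.x sketch of the pattern it expects (shape and range values illustrative): min/max arrive as the op's second and third inputs, while num_bits and narrow_range are attributes, matching the reads above.

    import tensorflow as tf  # TF 1.x API assumed

    x = tf.placeholder(tf.float32, [1, 224, 224, 3])
    y = tf.fake_quant_with_min_max_vars(
        x, min=-3.0, max=3.0, num_bits=8, narrow_range=False)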
......@@ -47,6 +47,8 @@ class Transformer(base_converter.ConverterInterface):
# Dependencies
# (TRANSFORM_MATMUL_TO_FC, TRANSFORM_GLOBAL_CONV_TO_FC) -> RESHAPE_FC_WEIGHT # noqa
self._registered_transformers = {
TransformerRule.TRANSFORM_FAKE_QUANTIZE:
self.transform_fake_quantize,
TransformerRule.REMOVE_IDENTITY_OP: self.remove_identity_op,
TransformerRule.TRANSFORM_GLOBAL_POOLING:
self.transform_global_pooling,
......@@ -91,6 +93,8 @@ class Transformer(base_converter.ConverterInterface):
TransformerRule.ADD_MACE_INPUT_AND_OUTPUT_NODES:
self.add_mace_input_and_output_nodes,
TransformerRule.SORT_BY_EXECUTION: self.sort_by_execution,
TransformerRule.CHECK_QUANTIZE_INFO:
self.check_quantize_info,
}
self._option = option
......@@ -774,16 +778,22 @@ class Transformer(base_converter.ConverterInterface):
def transform_add_to_biasadd(self):
net = self._model
for op in net.op:
if op.type == 'Add' \
and len(op.input) == 2 \
and op.input[1] in self._consts \
and len(self._consts[op.input[1]].dims) == 1:
if (op.type == 'Eltwise'
and ConverterUtil.get_arg(op, MaceKeyword.mace_element_type_str).i == EltwiseType.SUM.value # noqa
and len(op.input) == 2
and op.input[1] in self._consts
and len(self._consts[op.input[1]].dims) == 1):
print("Transform add to biasadd: %s(%s)" % (op.name, op.type))
op.type = MaceOp.BiasAdd.name
return True
return False
def replace_quantize_info(self, op, replace_op):
if len(replace_op.quantize_info) > 0:
del op.quantize_info[:]
op.quantize_info.extend(replace_op.quantize_info)
def fold_biasadd(self):
net = self._model
for op in net.op:
......@@ -799,6 +809,7 @@ class Transformer(base_converter.ConverterInterface):
if consumer_op.type == MaceOp.BiasAdd.name:
print("Fold biasadd: %s(%s)" % (op.name, op.type))
op.input.append(consumer_op.input[1])
self.replace_quantize_info(op, consumer_op)
self.safe_remove_node(consumer_op, op)
return True
......@@ -886,6 +897,7 @@ class Transformer(base_converter.ConverterInterface):
or arg.name == MaceKeyword.mace_activation_max_limit_str: # noqa
op.arg.extend([arg])
self.replace_quantize_info(op, consumer_op)
self.safe_remove_node(consumer_op, op)
return True
......@@ -1163,7 +1175,8 @@ class Transformer(base_converter.ConverterInterface):
transposed_filter = set()
transposed_deconv_filter = set()
if self._option.quantize:
if self._option.quantize and \
self._option.device == DeviceType.CPU.value:
print("Transpose filters to OHWI")
if filter_format == FilterFormat.HWIO:
transpose_order = [3, 0, 1, 2]
......@@ -1601,6 +1614,9 @@ class Transformer(base_converter.ConverterInterface):
return False
def quantize_nodes(self):
if not self._option.quantize:
return False
print("Add mace quantize and dequantize nodes")
for op in self._model.op:
......@@ -1647,28 +1663,13 @@ class Transformer(base_converter.ConverterInterface):
self._input_output_added = True
def add_quantize_tensor_range(self):
print("Add quantize tensor range")
net = self._model
range_file = self._option.quantize_range_file
if not range_file:
return
with open(range_file) as f:
for line in f:
tensor_name, minmax = line.split("@@")
min_val, max_val = [float(i) for i in
minmax.strip().split(",")]
scale, zero = quantize_util.adjust_range(min_val, max_val,
non_zero=False)
activation_info = net.quantize_info.activation_info.add()
activation_info.tensor_name = tensor_name
activation_info.scale = scale
activation_info.zero_point = zero
self._quantize_activation_info[tensor_name] = activation_info
return False
def quantize_tensor(self, tensor):
"""Assume biasadd has been already folded with convolution and fc"""
if not self._option.quantize:
return False
if tensor.data_type == mace_pb2.DT_FLOAT:
ops = self._consumers.get(tensor.name, None)
if len(ops) == 1 and ops[0].type in [MaceOp.Conv2D.name,
......@@ -1698,8 +1699,131 @@ class Transformer(base_converter.ConverterInterface):
tensor.zero_point = quantized_tensor.zero
self._quantized_tensor.update([tensor.name])
return False
def quantize_weights(self):
print("Quantize weights")
net = self._model
for tensor in net.tensors:
self.quantize_tensor(tensor)
return False
def add_quantize_info(self, op, minval, maxval):
scale, zero = quantize_util.adjust_range(minval, maxval,
non_zero=False)
quantize_info = op.quantize_info.add()
quantize_info.minval = minval
quantize_info.maxval = maxval
quantize_info.scale = scale
quantize_info.zero_point = zero
return quantize_info
def transform_fake_quantize(self):
if not self._option.quantize:
return False
# Quantize info from fixpoint fine tune
print("Transform fake quantize")
range_file = self._option.quantize_range_file
if range_file:
return False
net = self._model
for op in net.op:
if op.type == 'FakeQuantWithMinMaxVars':
producer_op = self._producer[op.input[0]]
minval = ConverterUtil.get_arg(op, 'min').f
maxval = ConverterUtil.get_arg(op, 'max').f
quantize_info = \
self.add_quantize_info(producer_op, minval, maxval)
self._quantize_activation_info[op.input[0]] = quantize_info
op.type = MaceOp.Identity.name
return False
def add_quantize_tensor_range(self):
if not self._option.quantize:
return False
# Quantize info from range statistics
print("Add quantize tensor range")
range_file = self._option.quantize_range_file
if range_file:
with open(range_file) as f:
for line in f:
tensor_name, minmax = line.split("@@")
min_val, max_val = [float(i) for i in
minmax.strip().split(",")]
scale, zero = quantize_util.adjust_range(min_val, max_val,
non_zero=False)
activation_info = mace_pb2.QuantizeActivationInfo()
activation_info.minval = min_val
activation_info.maxval = max_val
activation_info.scale = scale
activation_info.zero_point = zero
self._quantize_activation_info[tensor_name] = activation_info # noqa
for op in self._model.op:
if op.name.find(MaceKeyword.mace_output_node_name) >= 0:
continue
for output in op.output:
mace_check(output in self._quantize_activation_info,
"%s does not have quantize activation info"
% op)
op.quantize_info.extend([
self._quantize_activation_info[output]
for output in op.output])
print ("Add default quantize info for ops like Pooling, Softmax")
for op in self._model.op:
if op.type in [MaceOp.Pooling.name,
MaceOp.Squeeze.name,
MaceOp.Concat.name,
MaceOp.ResizeBilinear.name,
MaceOp.BatchToSpaceND.name,
MaceOp.SpaceToBatchND.name]:
del op.quantize_info[:]
producer_op = self._producer[op.input[0]]
quantize_info = op.quantize_info.add()
quantize_info.minval = producer_op.quantize_info[0].minval
quantize_info.maxval = producer_op.quantize_info[0].maxval
quantize_info.scale = producer_op.quantize_info[0].scale
quantize_info.zero_point = \
producer_op.quantize_info[0].zero_point
self._quantize_activation_info[op.output[0]] = quantize_info
elif op.type == MaceOp.Softmax.name:
del op.quantize_info[:]
quantize_info = \
self.add_quantize_info(op, 0.0, 1.0)
self._quantize_activation_info[op.output[0]] = quantize_info
print ("Add default quantize info for input")
for input_node in self._option.input_nodes.values():
if input_node.name not in self._quantize_activation_info:
print("Input range %s: %s" % (input_node.name,
str(input_node.range)))
scale, zero = quantize_util.adjust_range(input_node.range[0],
input_node.range[1],
non_zero=False)
quantize_info = mace_pb2.QuantizeActivationInfo()
quantize_info.minval = input_node.range[0]
quantize_info.maxval = input_node.range[1]
quantize_info.scale = scale
quantize_info.zero_point = zero
self._quantize_activation_info[input_node.name] = quantize_info
return False
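The range file parsed above holds one tensor per line, with the tensor name and its observed "min,max" separated by "@@". An illustrative entry and its parse (name made up):

    line = "conv1/Relu:0@@-0.12,6.33"
    tensor_name, minmax = line.split("@@")
    min_val, max_val = [float(i) for i in minmax.strip().split(",")]
    # -> ("conv1/Relu:0", -0.12, 6.33)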
def check_quantize_info(self):
if not self._option.quantize:
return False
for op in self._model.op:
if (op.name.find(MaceKeyword.mace_input_node_name) == -1
and op.name.find(MaceKeyword.mace_output_node_name) == -1
and op.type != MaceOp.Quantize.name
and op.type != MaceOp.Dequantize.name): # noqa
mace_check(len(op.output) == len(op.quantize_info),
"missing quantize info: %s" % op)
......@@ -80,9 +80,13 @@ class MemoryOptimizer(object):
def op_need_optimize_memory(self, op):
return True
def get_op_mem_block(self, op_type, output_shape):
def get_op_mem_block(self, op_type, output_shape, output_type):
data_type_size = 4
if output_type == mace_pb2.DT_UINT8:
data_type_size = 1
return MemoryBlock(mace_pb2.CPU_BUFFER,
[reduce(operator.mul, output_shape, 1)])
[reduce(operator.mul, output_shape, 1) *
data_type_size])
def mem_size(self, memory_block):
return memory_block.block[0]
......@@ -143,9 +147,13 @@ class MemoryOptimizer(object):
# make these ops reuse memory of input tensor
mem_id = self.op_mem.get(op.input[0], -1)
else:
output_type = mace_pb2.DT_FLOAT
if len(op.output_type) > i:
output_type = op.output_type[i]
op_mem_block = self.get_op_mem_block(
op.type,
op.output_shape[i].dims)
op.output_shape[i].dims,
output_type)
mem_id = -1
if len(self.idle_mem) > 0:
best_mem_add_size = sys.maxint
......@@ -221,7 +229,7 @@ class GPUMemoryOptimizer(MemoryOptimizer):
return False
return op.type != 'ImageToBuffer'
def get_op_mem_block(self, op_type, output_shape):
def get_op_mem_block(self, op_type, output_shape, output_type):
if op_type == 'WinogradTransform' or op_type == 'MatMul':
buffer_shape = list(output_shape) + [1]
mem_block = MemoryBlock(
......
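Threading output_type through get_op_mem_block means a quantized uint8 output now reserves a quarter of the bytes a float output would. A quick check of the arithmetic (shape illustrative):

    # CPU_BUFFER block size in bytes, as computed above.
    import operator
    from functools import reduce  # a builtin in the py2 code above

    shape = [1, 112, 112, 32]
    elems = reduce(operator.mul, shape, 1)
    print(elems * 4)  # DT_FLOAT block: 1605632 bytes
    print(elems * 1)  # DT_UINT8 block:  401408 bytes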
......@@ -138,20 +138,6 @@ void CreateMemoryArena(mace::MemoryArena *mem_arena) {
}
{% endif %}
void AddQuantizeInfo(NetDef *net_def) {
MACE_LATENCY_LOGGER(1, "Add quantize info");
(void) net_def;
{% for i in range(net.quantize_info.activation_info|length) %}
mace::QuantizeActivationInfo *activation_info{{i}} =
net_def->mutable_quantize_info()->add_activation_info();
activation_info{{i}}->set_tensor_name("{{net.quantize_info.activation_info[i].tensor_name}}");
activation_info{{i}}->set_scale({{net.quantize_info.activation_info[i].scale}});
activation_info{{i}}->set_zero_point({{net.quantize_info.activation_info[i].zero_point}});
{% endfor %}
}
} // namespace
namespace {{tag}} {
......@@ -179,8 +165,6 @@ const std::shared_ptr<NetDef> CreateNet() {
CreateOutputInfo(net_def.get());
{% endif %}
AddQuantizeInfo(net_def.get());
return net_def;
}
......
......@@ -122,6 +122,16 @@ void CreateOperator{{i}}(mace::OperatorDef *op) {
{{ net.op[i].node_id }},
{ {{ net.op[i].mem_id | join(', ') }} });
op->mutable_quantize_info()->Reserve({{ net.op[i].quantize_info | length }});
{% for j in range(net.op[i].quantize_info|length) %}
auto quantize_info{{j}} = op->add_quantize_info();
quantize_info{{j}}->set_scale({{ net.op[i].quantize_info[j].scale }});
quantize_info{{j}}->set_zero_point({{ net.op[i].quantize_info[j].zero_point }});
quantize_info{{j}}->set_minval({{ net.op[i].quantize_info[j].minval }});
quantize_info{{j}}->set_maxval({{ net.op[i].quantize_info[j].maxval }});
{% endfor %}
{% if runtime == 'dsp' %}
op->set_padding({{ net.op[i].padding }});
{% if net.op[i].node_input | length > 0 %}
......
......@@ -760,6 +760,7 @@ def convert_model(configs):
runtime,
model_name,
":".join(subgraphs[0][YAMLKeyword.input_shapes]),
":".join(subgraphs[0][YAMLKeyword.input_ranges]),
model_config[YAMLKeyword.nnlib_graph_mode],
embed_model_data,
model_config[YAMLKeyword.winograd],
......
......@@ -551,6 +551,7 @@ def gen_model_code(model_codegen_dir,
runtime,
model_tag,
input_shapes,
input_ranges,
dsp_mode,
embed_model_data,
winograd,
......@@ -579,6 +580,7 @@ def gen_model_code(model_codegen_dir,
"--template=%s" % "mace/python/tools",
"--model_tag=%s" % model_tag,
"--input_shape=%s" % input_shapes,
"--input_range=%s" % input_ranges,
"--dsp_mode=%s" % dsp_mode,
"--embed_model_data=%s" % embed_model_data,
"--winograd=%s" % winograd,
......