Commit be4b81bd authored by 李寅

Remove the prequantize option and make quantization automatic (AutoQuantize)

Parent fc7a469c
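This commit removes the --prequantize path, which stripped the INPUT-side Flatten/Min/Max/Quantize chain and the OUTPUT-side Dequantize so that converted models required pre-quantized uint8 inputs and produced uint8 outputs. The converter now always runs fuse_quantize, which still drops the host-side Flatten/Min/Max range computation but keeps the Quantize op, rewriting it into an AutoQuantize op that derives its range on device; the output-side Dequantize is left in place, so the converted graph keeps DT_FLOAT input/output metadata.

A minimal sketch (not part of this commit) of the op rewrite at the heart of fuse_quantize; the tensor names are made up for illustration, and the mace_pb2 import path is assumed from the converter's identifiers:

    from mace.proto import mace_pb2  # import path assumed

    op = mace_pb2.OperatorDef()
    op.name = 'quantize'
    op.type = 'Quantize'
    # Before: Quantize takes explicit min/max inputs computed by host-side ops.
    op.input.extend(['input:0', 'input_min:0', 'input_max:0'])

    # After the rewrite: AutoQuantize derives the range itself, so the
    # min/max inputs are dropped.
    op.type = 'AutoQuantize'
    del op.input[1:]

    assert list(op.input) == ['input:0']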
@@ -32,7 +32,7 @@ def main(unused_args):
   if FLAGS.runtime == 'dsp':
     output_graph_def = tf_dsp_converter_lib.convert_to_mace_pb(
-        input_graph_def, FLAGS.input_node, FLAGS.output_node, FLAGS.prequantize)
+        input_graph_def, FLAGS.input_node, FLAGS.output_node)
   else:
     output_graph_def = tf_converter_lib.convert_to_mace_pb(
         input_graph_def, FLAGS.input_node, FLAGS.output_node, FLAGS.data_type, FLAGS.runtime)
@@ -85,11 +85,6 @@ def parse_args():
       type=str,
       default="softmax",
       help="e.g., softmax")
-  parser.add_argument(
-      "--prequantize",
-      type=bool,
-      default=True,
-      help="e.g., True")
   parser.add_argument(
       "--data_type",
       type=str,
...
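A side note on the deleted flag (an observation about the removed code, not something stated in the commit): argparse's type=bool converts the raw string with bool(), so any non-empty value, including "False", parsed as True; the flag effectively could not be disabled from the command line. Removing it sidesteps that pitfall:

    import argparse

    # Reconstruction of the deleted flag definition.
    p = argparse.ArgumentParser()
    p.add_argument('--prequantize', type=bool, default=True)

    # bool('False') is True, so this still prints True.
    print(p.parse_args(['--prequantize', 'False']).prequantize)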
@@ -288,7 +288,7 @@ def add_input_output_info(net_def, input_node, output_node, graph, dtype):
   return net_def
 
-def strip_input_quantize_and_output_dequantize(net_def, input_node, output_node):
+def fuse_quantize(net_def, input_node, output_node):
   tensor_map = {}
   for tensor in net_def.tensors:
     tensor_map[tensor.name] = tensor
@@ -319,42 +319,10 @@ def strip_input_quantize_and_output_dequantize(net_def, input_node, output_node)
           quantize_op = o
       if quantize_op is not None:
         minf_op, maxf_op = consumers[get_tensor_name_from_op(flatten_op.name, 0)]
-        skip_ops = skip_ops.union([input_op.name, flatten_op.name, minf_op.name, maxf_op.name, quantize_op.name])
+        skip_ops = skip_ops.union([flatten_op.name, minf_op.name, maxf_op.name])
         skip_tensors = skip_tensors.union([flatten_op.input[1], minf_op.input[1], maxf_op.input[1]])
-        new_input_op = mace_pb2.OperatorDef()
-        new_input_op.name = input_op.name
-        new_input_op.type = input_op.type
-        new_input_op.padding = input_op.padding
-        new_input_op.out_max_byte_size.extend([input_op.out_max_byte_size[0]/4, 4, 4])
-        new_ops.append(new_input_op)
-        new_input_op.output_shape.extend([input_op.output_shape[0],
-                                          minf_op.output_shape[0],
-                                          maxf_op.output_shape[0]])
-        new_input_op.output_type.extend([input_op.output_type[0], mace_pb2.DT_FLOAT, mace_pb2.DT_FLOAT])
-        for follow_op in consumers[get_tensor_name_from_op(quantize_op.name, 0)]:
-          new_follow_op = mace_pb2.OperatorDef()
-          new_follow_op.CopyFrom(follow_op)
-          for i in xrange(len(follow_op.input)):
-            for k in xrange(3):
-              if new_follow_op.input[i] == get_tensor_name_from_op(quantize_op.name, k):
-                new_follow_op.input[i] = get_tensor_name_from_op(input_op.name, k)
-          new_ops.append(new_follow_op)
-          skip_ops.add(follow_op.name)
-    elif op.type == 'OUTPUT':
-      output_op = op
-      dequantize_op = get_node_from_map(op_map, output_op.input[0])
-      if dequantize_op.type == 'Dequantize':
-        skip_ops = skip_ops.union([dequantize_op.name, output_op.name])
-        new_output_op = mace_pb2.OperatorDef()
-        new_output_op.name = output_op.name
-        new_output_op.type = output_op.type
-        new_output_op.input.extend(dequantize_op.input)
-        new_ops.append(new_output_op)
+        quantize_op.type = 'AutoQuantize'
+        del quantize_op.input[1:]
 
   new_net_def = mace_pb2.NetDef()
   new_net_def.tensors.extend([tensor for tensor in net_def.tensors if tensor.name not in skip_tensors])
@@ -362,7 +330,7 @@ def strip_input_quantize_and_output_dequantize(net_def, input_node, output_node)
   new_net_def.op.extend(new_ops)
   return new_net_def
 
-def convert_to_mace_pb(input_graph_def, input_node, output_node, prequantize=False):
+def convert_to_mace_pb(input_graph_def, input_node, output_node):
   """
   nnlib does not have batch norm, so use tensorflow optimizer to fold
   batch norm with convolution. The fold optimization reorders ops, so
@@ -388,19 +356,13 @@ def convert_to_mace_pb(input_graph_def, input_node, output_node, prequantize=Fal
   convert_ops(unresolved_ops, resolved_ops, net_def, output_node, dsp_ops)
   add_output_node(net_def, output_node)
-  # optimized_net_def = reverse_batch_to_space_and_biasadd(net_def)
-
-  if prequantize:
-    print('Prequantize ...')
-    net_def = strip_input_quantize_and_output_dequantize(net_def, input_node, output_node)
+  net_def = reverse_batch_to_space_and_biasadd(net_def)
+  net_def = fuse_quantize(net_def, input_node, output_node)
 
   sorted_net_def = graph_util.sort_mace_graph(net_def, '__output__')
   net_def_with_node_id = add_node_id(sorted_net_def)
-  if prequantize:
-    dtype = mace_pb2.DT_UINT8
-  else:
-    dtype = mace_pb2.DT_FLOAT
+  dtype = mace_pb2.DT_FLOAT
   final_net_def = add_input_output_info(net_def_with_node_id, input_node, output_node, graph, dtype)
 
   return final_net_def
...
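With the prequantize parameter gone, a DSP conversion reduces to the three remaining arguments. A hedged usage sketch (TF1-era APIs; the module import, model path, and node names below are placeholders, not taken from this commit):

    import tensorflow as tf
    import tf_dsp_converter_lib  # module location assumed

    # Load a frozen TensorFlow graph (placeholder path).
    with tf.gfile.GFile('frozen_model.pb', 'rb') as f:
      input_graph_def = tf.GraphDef()
      input_graph_def.ParseFromString(f.read())

    # Quantize fusion now always runs; I/O info is emitted as DT_FLOAT.
    net_def = tf_dsp_converter_lib.convert_to_mace_pb(
        input_graph_def, 'input_node', 'softmax')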