“a6245433623661c6da9c9c32217a64a2ede1f87d”上不存在“modules/image/classification/resnet50_vd_dishes/module.py”
提交 be4b81bd 编写于 作者: 李寅

Remove prequantize, and make autoquantize

上级 fc7a469c
......@@ -32,7 +32,7 @@ def main(unused_args):
if FLAGS.runtime == 'dsp':
output_graph_def = tf_dsp_converter_lib.convert_to_mace_pb(
input_graph_def, FLAGS.input_node, FLAGS.output_node, FLAGS.prequantize)
input_graph_def, FLAGS.input_node, FLAGS.output_node)
else:
output_graph_def = tf_converter_lib.convert_to_mace_pb(
input_graph_def, FLAGS.input_node, FLAGS.output_node, FLAGS.data_type, FLAGS.runtime)
......@@ -85,11 +85,6 @@ def parse_args():
type=str,
default="softmax",
help="e.g., softmax")
parser.add_argument(
"--prequantize",
type=bool,
default=True,
help="e.g., True")
parser.add_argument(
"--data_type",
type=str,
......
......@@ -288,7 +288,7 @@ def add_input_output_info(net_def, input_node, output_node, graph, dtype):
return net_def
def strip_input_quantize_and_output_dequantize(net_def, input_node, output_node):
def fuse_quantize(net_def, input_node, output_node):
tensor_map = {}
for tensor in net_def.tensors:
tensor_map[tensor.name] = tensor
......@@ -319,42 +319,10 @@ def strip_input_quantize_and_output_dequantize(net_def, input_node, output_node)
quantize_op = o
if quantize_op is not None:
minf_op, maxf_op = consumers[get_tensor_name_from_op(flatten_op.name, 0)]
skip_ops = skip_ops.union([input_op.name, flatten_op.name, minf_op.name, maxf_op.name, quantize_op.name])
skip_ops = skip_ops.union([flatten_op.name, minf_op.name, maxf_op.name])
skip_tensors = skip_tensors.union([flatten_op.input[1], minf_op.input[1], maxf_op.input[1]])
new_input_op = mace_pb2.OperatorDef()
new_input_op.name = input_op.name
new_input_op.type = input_op.type
new_input_op.padding = input_op.padding
new_input_op.out_max_byte_size.extend([input_op.out_max_byte_size[0]/4, 4, 4])
new_ops.append(new_input_op)
new_input_op.output_shape.extend([input_op.output_shape[0],
minf_op.output_shape[0],
maxf_op.output_shape[0]])
new_input_op.output_type.extend([input_op.output_type[0], mace_pb2.DT_FLOAT, mace_pb2.DT_FLOAT])
for follow_op in consumers[get_tensor_name_from_op(quantize_op.name, 0)]:
new_follow_op = mace_pb2.OperatorDef()
new_follow_op.CopyFrom(follow_op)
for i in xrange(len(follow_op.input)):
for k in xrange(3):
if new_follow_op.input[i] == get_tensor_name_from_op(quantize_op.name, k):
new_follow_op.input[i] = get_tensor_name_from_op(input_op.name, k)
new_ops.append(new_follow_op)
skip_ops.add(follow_op.name)
elif op.type == 'OUTPUT':
output_op = op
dequantize_op = get_node_from_map(op_map, output_op.input[0])
if dequantize_op.type == 'Dequantize':
skip_ops = skip_ops.union([dequantize_op.name, output_op.name])
new_output_op = mace_pb2.OperatorDef()
new_output_op.name = output_op.name
new_output_op.type = output_op.type
new_output_op.input.extend(dequantize_op.input)
new_ops.append(new_output_op)
quantize_op.type = 'AutoQuantize'
del quantize_op.input[1:]
new_net_def = mace_pb2.NetDef()
new_net_def.tensors.extend([tensor for tensor in net_def.tensors if tensor.name not in skip_tensors])
......@@ -362,7 +330,7 @@ def strip_input_quantize_and_output_dequantize(net_def, input_node, output_node)
new_net_def.op.extend(new_ops)
return new_net_def
def convert_to_mace_pb(input_graph_def, input_node, output_node, prequantize=False):
def convert_to_mace_pb(input_graph_def, input_node, output_node):
"""
nnlib does not have batch norm, so use tensorflow optimizer to fold
batch norm with convolution. The fold optimization reorders ops, so
......@@ -388,19 +356,13 @@ def convert_to_mace_pb(input_graph_def, input_node, output_node, prequantize=Fal
convert_ops(unresolved_ops, resolved_ops, net_def, output_node, dsp_ops)
add_output_node(net_def, output_node)
# optimized_net_def = reverse_batch_to_space_and_biasadd(net_def)
if prequantize:
print('Prequantize ...')
net_def = strip_input_quantize_and_output_dequantize(net_def, input_node, output_node)
net_def = reverse_batch_to_space_and_biasadd(net_def)
net_def = fuse_quantize(net_def, input_node, output_node)
sorted_net_def = graph_util.sort_mace_graph(net_def, '__output__')
net_def_with_node_id = add_node_id(sorted_net_def)
if prequantize:
dtype = mace_pb2.DT_UINT8
else:
dtype = mace_pb2.DT_FLOAT
dtype = mace_pb2.DT_FLOAT
final_net_def = add_input_output_info(net_def_with_node_id, input_node, output_node, graph, dtype)
return final_net_def
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册