提交 76004dd7 编写于 作者: 卢旭辉

Merge branch 'validate' into 'master'

Support layers validate for HTA

See merge request !1222
......@@ -70,31 +70,36 @@ def convert(model_file, output_dir, layers):
with open(model_file, "rb") as f:
net_def.ParseFromString(f.read())
quantize_flag = ConverterUtil.get_arg(
is_quantize = ConverterUtil.get_arg(
net_def, MaceKeyword.mace_quantize_flag_arg_str)
quantize_flag = False if quantize_flag is None else quantize_flag.i == 1
hexagon_flag = False
is_quantize = False if is_quantize is None else is_quantize.i == 1
is_hexagon = False
index = 0
end_index = len(net_def.op)
if quantize_flag:
if is_quantize:
while index < end_index:
# omit op quantize
if net_def.op[index].type == MaceOp.Quantize.name or \
net_def.op[index].type == \
HexagonOp.QuantizeINPUT_f_to_8.name:
HexagonOp.QuantizeINPUT_f_to_8.name or \
net_def.op[index].type == HexagonOp.INPUT.name:
index += 1
# omit op dequantize
elif net_def.op[end_index - 1].type == MaceOp.Dequantize.name or \
net_def.op[end_index - 1].type == \
HexagonOp.DequantizeOUTPUT_8tof.name:
HexagonOp.DequantizeOUTPUT_8tof.name or \
net_def.op[end_index - 1].type == HexagonOp.OUTPUT.name:
end_index -= 1
else:
break
mace_check(0 < index < end_index < len(net_def.op),
"Wrong number of op quantize(%d) or dequantize(%d)." %
(index, len(net_def.op) - end_index))
if net_def.op[-1].type == HexagonOp.DequantizeOUTPUT_8tof.name:
hexagon_flag = True
if net_def.op[-1].type == HexagonOp.DequantizeOUTPUT_8tof.name or \
net_def.op[-1].type == HexagonOp.OUTPUT.name:
is_hexagon = True
# omit original output
end_index -= 1
......@@ -112,7 +117,7 @@ def convert(model_file, output_dir, layers):
index += 1
continue
net = copy.deepcopy(net_def)
if hexagon_flag:
if is_hexagon:
# reuse dequantize op and it's min/max tensor's node_id
del net.op[index+1:-1]
else:
......@@ -124,9 +129,9 @@ def convert(model_file, output_dir, layers):
output_tensors = []
output_shapes = []
op_name = op.name
if quantize_flag:
if is_quantize:
op.name = MaceKeyword.mace_output_node_name + '_' + op.name
if hexagon_flag:
if is_hexagon:
mace_check(len(op.output) == 1,
"Only supports number of outputs of Hexagon op be 1.")
for i in range(len(op.output)):
......@@ -139,13 +144,15 @@ def convert(model_file, output_dir, layers):
output_info.data_format = data_format
output_info.dims.extend(op.output_shape[i].dims)
output_info.data_type = mace_pb2.DT_FLOAT
output_info.scale = op.quantize_info[0].scale
output_info.zero_point = op.quantize_info[0].zero_point
# modify output op
if quantize_flag:
if is_quantize:
output_name = op.output[i]
new_output_name = \
MaceKeyword.mace_output_node_name + '_' + op.output[i]
op.output[i] = new_output_name
if not hexagon_flag:
if not is_hexagon:
dequantize_op = net.op.add()
dequantize_op.name = normalize_op_name(output_name)
dequantize_op.type = MaceOp.Dequantize.name
......@@ -162,14 +169,18 @@ def convert(model_file, output_dir, layers):
del dequantize_op.input[:]
del dequantize_op.output[:]
dequantize_op.input.append(new_output_name)
dequantize_op.output.append(output_name)
input_min = new_output_name[:-1] + '1'
input_max = new_output_name[:-1] + '2'
dequantize_op.input.extend([input_min, input_max])
dequantize_op.node_input[0].node_id = op.node_id
dequantize_op.node_input[1].node_id = op.node_id
dequantize_op.node_input[2].node_id = op.node_id
del dequantize_op.node_input[3:]
dequantize_op.output.append(output_name)
if dequantize_op.type == \
HexagonOp.DequantizeOUTPUT_8tof.name:
input_min = new_output_name[:-1] + '1'
input_max = new_output_name[:-1] + '2'
dequantize_op.input.extend([input_min, input_max])
dequantize_op.node_input[1].node_id = op.node_id
dequantize_op.node_input[2].node_id = op.node_id
del dequantize_op.node_input[3:]
else:
del dequantize_op.node_input[1:]
model_path = save_model_to_proto(net, normalize_op_name(op_name),
output_dir)
......
......@@ -939,8 +939,11 @@ class Transformer(base_converter.ConverterInterface):
# update output shape
conv_op.output_shape[0].dims[:] = \
b2s_op.output_shape[0].dims[:]
conv_op.output[0] = b2s_op.output[0]
conv_op.name = b2s_op.name
self.safe_remove_node(op, None)
self.replace_quantize_info(b2s_op, conv_op)
self.safe_remove_node(b2s_op, conv_op)
return True
return False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册