提交 76004dd7 编写于 作者: 卢旭辉

Merge branch 'validate' into 'master'

Support layers validate for HTA

See merge request !1222
...@@ -70,31 +70,36 @@ def convert(model_file, output_dir, layers): ...@@ -70,31 +70,36 @@ def convert(model_file, output_dir, layers):
with open(model_file, "rb") as f: with open(model_file, "rb") as f:
net_def.ParseFromString(f.read()) net_def.ParseFromString(f.read())
quantize_flag = ConverterUtil.get_arg( is_quantize = ConverterUtil.get_arg(
net_def, MaceKeyword.mace_quantize_flag_arg_str) net_def, MaceKeyword.mace_quantize_flag_arg_str)
quantize_flag = False if quantize_flag is None else quantize_flag.i == 1 is_quantize = False if is_quantize is None else is_quantize.i == 1
hexagon_flag = False is_hexagon = False
index = 0 index = 0
end_index = len(net_def.op) end_index = len(net_def.op)
if quantize_flag: if is_quantize:
while index < end_index: while index < end_index:
# omit op quantize # omit op quantize
if net_def.op[index].type == MaceOp.Quantize.name or \ if net_def.op[index].type == MaceOp.Quantize.name or \
net_def.op[index].type == \ net_def.op[index].type == \
HexagonOp.QuantizeINPUT_f_to_8.name: HexagonOp.QuantizeINPUT_f_to_8.name or \
net_def.op[index].type == HexagonOp.INPUT.name:
index += 1 index += 1
# omit op dequantize # omit op dequantize
elif net_def.op[end_index - 1].type == MaceOp.Dequantize.name or \ elif net_def.op[end_index - 1].type == MaceOp.Dequantize.name or \
net_def.op[end_index - 1].type == \ net_def.op[end_index - 1].type == \
HexagonOp.DequantizeOUTPUT_8tof.name: HexagonOp.DequantizeOUTPUT_8tof.name or \
net_def.op[end_index - 1].type == HexagonOp.OUTPUT.name:
end_index -= 1 end_index -= 1
else: else:
break break
mace_check(0 < index < end_index < len(net_def.op), mace_check(0 < index < end_index < len(net_def.op),
"Wrong number of op quantize(%d) or dequantize(%d)." % "Wrong number of op quantize(%d) or dequantize(%d)." %
(index, len(net_def.op) - end_index)) (index, len(net_def.op) - end_index))
if net_def.op[-1].type == HexagonOp.DequantizeOUTPUT_8tof.name: if net_def.op[-1].type == HexagonOp.DequantizeOUTPUT_8tof.name or \
hexagon_flag = True net_def.op[-1].type == HexagonOp.OUTPUT.name:
is_hexagon = True
# omit original output # omit original output
end_index -= 1 end_index -= 1
...@@ -112,7 +117,7 @@ def convert(model_file, output_dir, layers): ...@@ -112,7 +117,7 @@ def convert(model_file, output_dir, layers):
index += 1 index += 1
continue continue
net = copy.deepcopy(net_def) net = copy.deepcopy(net_def)
if hexagon_flag: if is_hexagon:
# reuse dequantize op and it's min/max tensor's node_id # reuse dequantize op and it's min/max tensor's node_id
del net.op[index+1:-1] del net.op[index+1:-1]
else: else:
...@@ -124,9 +129,9 @@ def convert(model_file, output_dir, layers): ...@@ -124,9 +129,9 @@ def convert(model_file, output_dir, layers):
output_tensors = [] output_tensors = []
output_shapes = [] output_shapes = []
op_name = op.name op_name = op.name
if quantize_flag: if is_quantize:
op.name = MaceKeyword.mace_output_node_name + '_' + op.name op.name = MaceKeyword.mace_output_node_name + '_' + op.name
if hexagon_flag: if is_hexagon:
mace_check(len(op.output) == 1, mace_check(len(op.output) == 1,
"Only supports number of outputs of Hexagon op be 1.") "Only supports number of outputs of Hexagon op be 1.")
for i in range(len(op.output)): for i in range(len(op.output)):
...@@ -139,13 +144,15 @@ def convert(model_file, output_dir, layers): ...@@ -139,13 +144,15 @@ def convert(model_file, output_dir, layers):
output_info.data_format = data_format output_info.data_format = data_format
output_info.dims.extend(op.output_shape[i].dims) output_info.dims.extend(op.output_shape[i].dims)
output_info.data_type = mace_pb2.DT_FLOAT output_info.data_type = mace_pb2.DT_FLOAT
output_info.scale = op.quantize_info[0].scale
output_info.zero_point = op.quantize_info[0].zero_point
# modify output op # modify output op
if quantize_flag: if is_quantize:
output_name = op.output[i] output_name = op.output[i]
new_output_name = \ new_output_name = \
MaceKeyword.mace_output_node_name + '_' + op.output[i] MaceKeyword.mace_output_node_name + '_' + op.output[i]
op.output[i] = new_output_name op.output[i] = new_output_name
if not hexagon_flag: if not is_hexagon:
dequantize_op = net.op.add() dequantize_op = net.op.add()
dequantize_op.name = normalize_op_name(output_name) dequantize_op.name = normalize_op_name(output_name)
dequantize_op.type = MaceOp.Dequantize.name dequantize_op.type = MaceOp.Dequantize.name
...@@ -162,14 +169,18 @@ def convert(model_file, output_dir, layers): ...@@ -162,14 +169,18 @@ def convert(model_file, output_dir, layers):
del dequantize_op.input[:] del dequantize_op.input[:]
del dequantize_op.output[:] del dequantize_op.output[:]
dequantize_op.input.append(new_output_name) dequantize_op.input.append(new_output_name)
dequantize_op.output.append(output_name)
input_min = new_output_name[:-1] + '1'
input_max = new_output_name[:-1] + '2'
dequantize_op.input.extend([input_min, input_max])
dequantize_op.node_input[0].node_id = op.node_id dequantize_op.node_input[0].node_id = op.node_id
dequantize_op.node_input[1].node_id = op.node_id dequantize_op.output.append(output_name)
dequantize_op.node_input[2].node_id = op.node_id if dequantize_op.type == \
del dequantize_op.node_input[3:] HexagonOp.DequantizeOUTPUT_8tof.name:
input_min = new_output_name[:-1] + '1'
input_max = new_output_name[:-1] + '2'
dequantize_op.input.extend([input_min, input_max])
dequantize_op.node_input[1].node_id = op.node_id
dequantize_op.node_input[2].node_id = op.node_id
del dequantize_op.node_input[3:]
else:
del dequantize_op.node_input[1:]
model_path = save_model_to_proto(net, normalize_op_name(op_name), model_path = save_model_to_proto(net, normalize_op_name(op_name),
output_dir) output_dir)
......
...@@ -939,8 +939,11 @@ class Transformer(base_converter.ConverterInterface): ...@@ -939,8 +939,11 @@ class Transformer(base_converter.ConverterInterface):
# update output shape # update output shape
conv_op.output_shape[0].dims[:] = \ conv_op.output_shape[0].dims[:] = \
b2s_op.output_shape[0].dims[:] b2s_op.output_shape[0].dims[:]
conv_op.output[0] = b2s_op.output[0]
conv_op.name = b2s_op.name
self.safe_remove_node(op, None) self.safe_remove_node(op, None)
self.replace_quantize_info(b2s_op, conv_op)
self.safe_remove_node(b2s_op, conv_op) self.safe_remove_node(b2s_op, conv_op)
return True return True
return False return False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册