diff --git a/tools/layers_validate.py b/tools/layers_validate.py index 8b8b29289dd751c7ddf07f131dbe360283a19eae..893db5e4b4188433f733d24c7cc64a881a55b91b 100644 --- a/tools/layers_validate.py +++ b/tools/layers_validate.py @@ -70,31 +70,36 @@ def convert(model_file, output_dir, layers): with open(model_file, "rb") as f: net_def.ParseFromString(f.read()) - quantize_flag = ConverterUtil.get_arg( + is_quantize = ConverterUtil.get_arg( net_def, MaceKeyword.mace_quantize_flag_arg_str) - quantize_flag = False if quantize_flag is None else quantize_flag.i == 1 - hexagon_flag = False + is_quantize = False if is_quantize is None else is_quantize.i == 1 + is_hexagon = False index = 0 end_index = len(net_def.op) - if quantize_flag: + if is_quantize: while index < end_index: # omit op quantize if net_def.op[index].type == MaceOp.Quantize.name or \ net_def.op[index].type == \ - HexagonOp.QuantizeINPUT_f_to_8.name: + HexagonOp.QuantizeINPUT_f_to_8.name or \ + net_def.op[index].type == HexagonOp.INPUT.name: index += 1 # omit op dequantize elif net_def.op[end_index - 1].type == MaceOp.Dequantize.name or \ net_def.op[end_index - 1].type == \ - HexagonOp.DequantizeOUTPUT_8tof.name: + HexagonOp.DequantizeOUTPUT_8tof.name or \ + net_def.op[end_index - 1].type == HexagonOp.OUTPUT.name: + end_index -= 1 else: break mace_check(0 < index < end_index < len(net_def.op), "Wrong number of op quantize(%d) or dequantize(%d)." % (index, len(net_def.op) - end_index)) - if net_def.op[-1].type == HexagonOp.DequantizeOUTPUT_8tof.name: - hexagon_flag = True + if net_def.op[-1].type == HexagonOp.DequantizeOUTPUT_8tof.name or \ + net_def.op[-1].type == HexagonOp.OUTPUT.name: + is_hexagon = True + # omit original output end_index -= 1 @@ -112,7 +117,7 @@ def convert(model_file, output_dir, layers): index += 1 continue net = copy.deepcopy(net_def) - if hexagon_flag: + if is_hexagon: # reuse dequantize op and it's min/max tensor's node_id del net.op[index+1:-1] else: @@ -124,9 +129,9 @@ def convert(model_file, output_dir, layers): output_tensors = [] output_shapes = [] op_name = op.name - if quantize_flag: + if is_quantize: op.name = MaceKeyword.mace_output_node_name + '_' + op.name - if hexagon_flag: + if is_hexagon: mace_check(len(op.output) == 1, "Only supports number of outputs of Hexagon op be 1.") for i in range(len(op.output)): @@ -139,13 +144,15 @@ def convert(model_file, output_dir, layers): output_info.data_format = data_format output_info.dims.extend(op.output_shape[i].dims) output_info.data_type = mace_pb2.DT_FLOAT + output_info.scale = op.quantize_info[0].scale + output_info.zero_point = op.quantize_info[0].zero_point # modify output op - if quantize_flag: + if is_quantize: output_name = op.output[i] new_output_name = \ MaceKeyword.mace_output_node_name + '_' + op.output[i] op.output[i] = new_output_name - if not hexagon_flag: + if not is_hexagon: dequantize_op = net.op.add() dequantize_op.name = normalize_op_name(output_name) dequantize_op.type = MaceOp.Dequantize.name @@ -162,14 +169,18 @@ def convert(model_file, output_dir, layers): del dequantize_op.input[:] del dequantize_op.output[:] dequantize_op.input.append(new_output_name) - dequantize_op.output.append(output_name) - input_min = new_output_name[:-1] + '1' - input_max = new_output_name[:-1] + '2' - dequantize_op.input.extend([input_min, input_max]) dequantize_op.node_input[0].node_id = op.node_id - dequantize_op.node_input[1].node_id = op.node_id - dequantize_op.node_input[2].node_id = op.node_id - del dequantize_op.node_input[3:] + dequantize_op.output.append(output_name) + if dequantize_op.type == \ + HexagonOp.DequantizeOUTPUT_8tof.name: + input_min = new_output_name[:-1] + '1' + input_max = new_output_name[:-1] + '2' + dequantize_op.input.extend([input_min, input_max]) + dequantize_op.node_input[1].node_id = op.node_id + dequantize_op.node_input[2].node_id = op.node_id + del dequantize_op.node_input[3:] + else: + del dequantize_op.node_input[1:] model_path = save_model_to_proto(net, normalize_op_name(op_name), output_dir) diff --git a/tools/python/transform/transformer.py b/tools/python/transform/transformer.py index a080c12e904a1e21fd4340313773b976b785a53d..0d8d58cf4f9689c9ffb9007c916505c664c303de 100644 --- a/tools/python/transform/transformer.py +++ b/tools/python/transform/transformer.py @@ -939,8 +939,11 @@ class Transformer(base_converter.ConverterInterface): # update output shape conv_op.output_shape[0].dims[:] = \ b2s_op.output_shape[0].dims[:] + conv_op.output[0] = b2s_op.output[0] + conv_op.name = b2s_op.name self.safe_remove_node(op, None) + self.replace_quantize_info(b2s_op, conv_op) self.safe_remove_node(b2s_op, conv_op) return True return False