提交 727cdea5 编写于 作者: Z Zhang, Guoming

Code clean

上级 ab9fe795
......@@ -192,24 +192,17 @@ def eval(args):
t = fluid.transpiler.InferenceTranspiler()
t.transpile(test_program, fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace())
# for i in test_program.current_block().ops:
# print i
# sys.exit(0)
conv_op_index = [index for index, value in enumerate(test_program.global_block().ops) if value.type == 'conv2d']
pooling_op_index = [index for index, value in enumerate(test_program.global_block().ops) if value.type == 'pool2d']
weights_var_name = []
conv_input_var_name = []
conv_output_var_name = []
# weights_channel = {}
for i in conv_op_index[1:]:
weights_var_name.append(test_program.current_block().ops[i].input('Filter')[0])
conv_input_var_name.append(test_program.current_block().ops[i].input('Input')[0])
conv_output_var_name.append(test_program.current_block().ops[i].output('Output')[0])
for i in pooling_op_index:
conv_input_var_name.append(test_program.current_block().ops[i].input('X')[0])
conv_output_var_name.append(test_program.current_block().ops[i].output('Out')[0])
not_persistable_vars = (i for i in test_program.list_vars() if not i.persistable)
back_program = test_program.clone()
for i in not_persistable_vars:
......@@ -287,6 +280,7 @@ def eval(args):
break
int8_prog = back_program.clone()
# for index, value in enumerate(conv_op_index[1:]):
# # print index,conv_input_var_name[index], ["{}_scale.input.test".format(conv_input_var_name[index])]
# int8_prog.current_block().ops[value].desc.set_input("Scale_in", ["{}_scale.input.test".format(conv_input_var_name[index])])
......@@ -325,6 +319,7 @@ def eval(args):
# for i in int8_prog.current_block().ops[quantize_pos[0] + 2:]:
# if i.type == 'conv2d' and i.input('Input')[0] == int8_prog.current_block().ops[quantize_pos[0] + 1].output('Out')[0]:
# i.desc.set_input("Input", ["conv2_quantize_tmp"])
# dequantize_pos = get_dequantization_op_pos(int8_prog)
# dequantize_tmp_var = int8_prog.current_block().create_var(
# name="dequantize_tmp_var",
......@@ -347,7 +342,7 @@ def eval(args):
# int8_prog.current_block().ops[dequantize_pos[0] + 2].desc.set_input("X", ["dequantize_tmp_var"])
#Step 3 Save the new model
print int8_prog
# print int8_prog
# for i in int8_prog.current_block().ops:
# print '********'
# print i
......@@ -367,6 +362,7 @@ def eval(args):
# dot(int8_prog)
# for i in int8_prog.current_block().ops:
# print i
print int8_prog
for batch_id, data in enumerate(val_reader()):
loss, acc1, acc5 = exe.run(int8_prog,
fetch_list=fetch_list,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册