提交 727cdea5 编写于 作者: Z Zhang, Guoming

Code clean

上级 ab9fe795
...@@ -192,24 +192,17 @@ def eval(args): ...@@ -192,24 +192,17 @@ def eval(args):
t = fluid.transpiler.InferenceTranspiler() t = fluid.transpiler.InferenceTranspiler()
t.transpile(test_program, fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()) t.transpile(test_program, fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace())
# for i in test_program.current_block().ops:
# print i
# sys.exit(0)
conv_op_index = [index for index, value in enumerate(test_program.global_block().ops) if value.type == 'conv2d'] conv_op_index = [index for index, value in enumerate(test_program.global_block().ops) if value.type == 'conv2d']
pooling_op_index = [index for index, value in enumerate(test_program.global_block().ops) if value.type == 'pool2d']
weights_var_name = [] weights_var_name = []
conv_input_var_name = [] conv_input_var_name = []
conv_output_var_name = [] conv_output_var_name = []
# weights_channel = {}
for i in conv_op_index[1:]: for i in conv_op_index[1:]:
weights_var_name.append(test_program.current_block().ops[i].input('Filter')[0]) weights_var_name.append(test_program.current_block().ops[i].input('Filter')[0])
conv_input_var_name.append(test_program.current_block().ops[i].input('Input')[0]) conv_input_var_name.append(test_program.current_block().ops[i].input('Input')[0])
conv_output_var_name.append(test_program.current_block().ops[i].output('Output')[0]) conv_output_var_name.append(test_program.current_block().ops[i].output('Output')[0])
for i in pooling_op_index:
conv_input_var_name.append(test_program.current_block().ops[i].input('X')[0])
conv_output_var_name.append(test_program.current_block().ops[i].output('Out')[0])
not_persistable_vars = (i for i in test_program.list_vars() if not i.persistable) not_persistable_vars = (i for i in test_program.list_vars() if not i.persistable)
back_program = test_program.clone() back_program = test_program.clone()
for i in not_persistable_vars: for i in not_persistable_vars:
...@@ -287,6 +280,7 @@ def eval(args): ...@@ -287,6 +280,7 @@ def eval(args):
break break
int8_prog = back_program.clone() int8_prog = back_program.clone()
# for index, value in enumerate(conv_op_index[1:]): # for index, value in enumerate(conv_op_index[1:]):
# # print index,conv_input_var_name[index], ["{}_scale.input.test".format(conv_input_var_name[index])] # # print index,conv_input_var_name[index], ["{}_scale.input.test".format(conv_input_var_name[index])]
# int8_prog.current_block().ops[value].desc.set_input("Scale_in", ["{}_scale.input.test".format(conv_input_var_name[index])]) # int8_prog.current_block().ops[value].desc.set_input("Scale_in", ["{}_scale.input.test".format(conv_input_var_name[index])])
...@@ -325,6 +319,7 @@ def eval(args): ...@@ -325,6 +319,7 @@ def eval(args):
# for i in int8_prog.current_block().ops[quantize_pos[0] + 2:]: # for i in int8_prog.current_block().ops[quantize_pos[0] + 2:]:
# if i.type == 'conv2d' and i.input('Input')[0] == int8_prog.current_block().ops[quantize_pos[0] + 1].output('Out')[0]: # if i.type == 'conv2d' and i.input('Input')[0] == int8_prog.current_block().ops[quantize_pos[0] + 1].output('Out')[0]:
# i.desc.set_input("Input", ["conv2_quantize_tmp"]) # i.desc.set_input("Input", ["conv2_quantize_tmp"])
# dequantize_pos = get_dequantization_op_pos(int8_prog) # dequantize_pos = get_dequantization_op_pos(int8_prog)
# dequantize_tmp_var = int8_prog.current_block().create_var( # dequantize_tmp_var = int8_prog.current_block().create_var(
# name="dequantize_tmp_var", # name="dequantize_tmp_var",
...@@ -332,7 +327,7 @@ def eval(args): ...@@ -332,7 +327,7 @@ def eval(args):
# persistable=True, # persistable=True,
# #shape= (np.array(fluid.global_scope().find_var('pool2d_0.tmp_0').get_tensor())).shape # #shape= (np.array(fluid.global_scope().find_var('pool2d_0.tmp_0').get_tensor())).shape
# ) # )
# op = int8_prog.current_block()._insert_op( # op = int8_prog.current_block()._insert_op(
# index=dequantize_pos[0] + 1, # index=dequantize_pos[0] + 1,
...@@ -347,7 +342,7 @@ def eval(args): ...@@ -347,7 +342,7 @@ def eval(args):
# int8_prog.current_block().ops[dequantize_pos[0] + 2].desc.set_input("X", ["dequantize_tmp_var"]) # int8_prog.current_block().ops[dequantize_pos[0] + 2].desc.set_input("X", ["dequantize_tmp_var"])
#Step 3 Save the new model #Step 3 Save the new model
print int8_prog # print int8_prog
# for i in int8_prog.current_block().ops: # for i in int8_prog.current_block().ops:
# print '********' # print '********'
# print i # print i
...@@ -367,6 +362,7 @@ def eval(args): ...@@ -367,6 +362,7 @@ def eval(args):
# dot(int8_prog) # dot(int8_prog)
# for i in int8_prog.current_block().ops: # for i in int8_prog.current_block().ops:
# print i # print i
print int8_prog
for batch_id, data in enumerate(val_reader()): for batch_id, data in enumerate(val_reader()):
loss, acc1, acc5 = exe.run(int8_prog, loss, acc1, acc5 = exe.run(int8_prog,
fetch_list=fetch_list, fetch_list=fetch_list,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册