提交 4702eb8c 编写于 作者: Z Zhang, Guoming

Fix the shape issue on quantize op

上级 4ab641e2
...@@ -176,7 +176,7 @@ def eval(args): ...@@ -176,7 +176,7 @@ def eval(args):
acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5) acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
test_program = fluid.default_main_program().clone(for_test=True) test_program = fluid.default_main_program().clone(for_test=True)
if with_memory_optimization: if with_memory_optimization:
fluid.memory_optimize(fluid.default_main_program()) fluid.memory_optimize(fluid.default_main_program())
...@@ -197,7 +197,6 @@ def eval(args): ...@@ -197,7 +197,6 @@ def eval(args):
# sys.exit(0) # sys.exit(0)
conv_op_index = [index for index, value in enumerate(test_program.global_block().ops) if value.type == 'conv2d'] conv_op_index = [index for index, value in enumerate(test_program.global_block().ops) if value.type == 'conv2d']
pooling_op_index = [index for index, value in enumerate(test_program.global_block().ops) if value.type == 'pool2d'] pooling_op_index = [index for index, value in enumerate(test_program.global_block().ops) if value.type == 'pool2d']
print (conv_op_index)
weights_var_name = [] weights_var_name = []
conv_input_var_name = [] conv_input_var_name = []
conv_output_var_name = [] conv_output_var_name = []
...@@ -211,9 +210,8 @@ def eval(args): ...@@ -211,9 +210,8 @@ def eval(args):
conv_input_var_name.append(test_program.current_block().ops[i].input('X')[0]) conv_input_var_name.append(test_program.current_block().ops[i].input('X')[0])
conv_output_var_name.append(test_program.current_block().ops[i].output('Out')[0]) conv_output_var_name.append(test_program.current_block().ops[i].output('Out')[0])
not_persistable_vars = (i for i in test_program.list_vars() if not i.persistable) not_persistable_vars = (i for i in test_program.list_vars() if not i.persistable)
back_program = test_program.clone()
for i in not_persistable_vars: for i in not_persistable_vars:
i.persistable= True i.persistable= True
...@@ -288,30 +286,31 @@ def eval(args): ...@@ -288,30 +286,31 @@ def eval(args):
feed=feeder.feed(data)) feed=feeder.feed(data))
break break
int8_prog = test_program.clone() int8_prog = back_program.clone()
for index, value in enumerate(conv_op_index[1:]): # for index, value in enumerate(conv_op_index[1:]):
# print index,conv_input_var_name[index], ["{}_scale.input.test".format(conv_input_var_name[index])] # # print index,conv_input_var_name[index], ["{}_scale.input.test".format(conv_input_var_name[index])]
int8_prog.current_block().ops[value].desc.set_input("Scale_in", ["{}_scale.input.test".format(conv_input_var_name[index])]) # int8_prog.current_block().ops[value].desc.set_input("Scale_in", ["{}_scale.input.test".format(conv_input_var_name[index])])
int8_prog.current_block().ops[value].desc.set_input("Scale_out", ["{}_scale.output.test".format(conv_output_var_name[index])]) # int8_prog.current_block().ops[value].desc.set_input("Scale_out", ["{}_scale.output.test".format(conv_output_var_name[index])])
int8_prog.current_block().ops[value].desc.set_input("Scale_weights", ["{}_scale.weights.test".format(weights_var_name[index])]) # int8_prog.current_block().ops[value].desc.set_input("Scale_weights", ["{}_scale.weights.test".format(weights_var_name[index])])
if int8_prog.current_block().ops[value].desc.input("ResidualData"): # if int8_prog.current_block().ops[value].desc.input("ResidualData"):
name = int8_prog.current_block().ops[value].desc.input("ResidualData")[0] # name = int8_prog.current_block().ops[value].desc.input("ResidualData")[0]
int8_prog.current_block().ops[value].desc.set_input("Scale_in_eltwise", ["{}_scale.output.test".format(name)]) # int8_prog.current_block().ops[value].desc.set_input("Scale_in_eltwise", ["{}_scale.output.test".format(name)])
quantize_pos = get_quantization_op_pos(int8_prog) quantize_pos = get_quantization_op_pos(int8_prog)
conv2_quantize_tmp = int8_prog.current_block().create_var( conv2_quantize_tmp = int8_prog.current_block().create_var(
name="conv2_quantize_tmp", name="conv2_quantize_tmp",
dtype="float32", dtype=core.VarDesc.VarType.UINT8,
persistable=True, # persistable=True,
#shape= (np.array(fluid.global_scope().find_var('pool2d_0.tmp_0').get_tensor())).shape # lod_level= 0,
# shape= shape
) )
op = int8_prog.current_block()._insert_op( op = int8_prog.current_block()._insert_op(
index=quantize_pos[0], index=quantize_pos[0] ,
type= "quantize", type="quantize",
inputs={"Input": int8_prog.current_block().ops[quantize_pos[0] - 1].output('Out')[0], inputs={"Input": int8_prog.current_block().ops[quantize_pos[0] - 1].output('Out')[0],
"Scale": "{}_scale.input.test".format(conv_input_var_name[1])}, "Scale": "{}_scale.input.test".format(conv_input_var_name[1])},
...@@ -321,33 +320,34 @@ def eval(args): ...@@ -321,33 +320,34 @@ def eval(args):
) )
op._set_attr("data_format", "NCHW") op._set_attr("data_format", "NCHW")
op._set_attr("use_mkldnn", 1) op._set_attr("use_mkldnn", 1)
int8_prog.current_block().ops[quantize_pos[0] + 1 ].desc.set_input("Input", ["conv2_quantize_tmp"])
for i in int8_prog.current_block().ops[quantize_pos[0] + 2:]:
if i.type == 'conv2d' and i.input('Input')[0] == int8_prog.current_block().ops[quantize_pos[0] - 1].output('Out')[0]:
i.desc.set_input("Input", ["conv2_quantize_tmp"])
dequantize_pos = get_dequantization_op_pos(int8_prog)
dequantize_tmp_var = int8_prog.current_block().create_var(
name="dequantize_tmp_var",
dtype="float32",
persistable=True,
#shape= (np.array(fluid.global_scope().find_var('pool2d_0.tmp_0').get_tensor())).shape
)
op = int8_prog.current_block()._insert_op( # int8_prog.current_block().ops[quantize_pos[0] + 1 ].desc.set_input("Input", ["conv2_quantize_tmp"])
index=dequantize_pos[0] + 1, # for i in int8_prog.current_block().ops[quantize_pos[0] + 2:]:
# if i.type == 'conv2d' and i.input('Input')[0] == int8_prog.current_block().ops[quantize_pos[0] + 1].output('Out')[0]:
# i.desc.set_input("Input", ["conv2_quantize_tmp"])
# dequantize_pos = get_dequantization_op_pos(int8_prog)
# dequantize_tmp_var = int8_prog.current_block().create_var(
# name="dequantize_tmp_var",
# dtype="float32",
# persistable=True,
# #shape= (np.array(fluid.global_scope().find_var('pool2d_0.tmp_0').get_tensor())).shape
# )
# op = int8_prog.current_block()._insert_op(
# index=dequantize_pos[0] + 1,
type= "dequantize", # type= "dequantize",
inputs={"Input": int8_prog.current_block().ops[dequantize_pos[0]].output('Out')[0], # inputs={"Input": int8_prog.current_block().ops[dequantize_pos[0]].output('Out')[0],
"Scale": "{}_scale.output.test".format( int8_prog.current_block().ops[dequantize_pos[0]].output('Out')[0])}, # "Scale": "{}_scale.output.test".format( int8_prog.current_block().ops[dequantize_pos[0]].output('Out')[0])},
outputs={"Output": dequantize_tmp_var}, # outputs={"Output": dequantize_tmp_var},
) # )
int8_prog.current_block().ops[dequantize_pos[0] + 2].desc.set_input("X", ["dequantize_tmp_var"]) # int8_prog.current_block().ops[dequantize_pos[0] + 2].desc.set_input("X", ["dequantize_tmp_var"])
#Step 3 Save the new model #Step 3 Save the new model
print int8_prog
# for i in int8_prog.current_block().ops: # for i in int8_prog.current_block().ops:
# print '********' # print '********'
# print i # print i
...@@ -362,9 +362,28 @@ def eval(args): ...@@ -362,9 +362,28 @@ def eval(args):
# print k, i.output(k)[0] # print k, i.output(k)[0]
# print conv_op_index # print conv_op_index
# print dequantize_pos # print dequantize_pos
if DEBUG: # sys.exit(0)
dot(int8_prog) # if DEBUG:
# dot(int8_prog)
# for i in int8_prog.current_block().ops:
# print i
for batch_id, data in enumerate(val_reader()):
loss, acc1, acc5 = exe.run(int8_prog,
fetch_list=fetch_list,
feed=feeder.feed(data))
loss = np.mean(loss)
acc1 = np.mean(acc1)
acc5 = np.mean(acc5)
test_info[0].append(loss * len(data))
test_info[1].append(acc1 * len(data))
test_info[2].append(acc5 * len(data))
cnt += len(data)
if batch_id % 10 == 0:
print("Testbatch {0},loss {1}, "
"acc1 {2},acc5 {3}".format(batch_id, \
loss, acc1, acc5))
sys.stdout.flush()
break
with open("__model_quantized__", "wb") as f: with open("__model_quantized__", "wb") as f:
f.write(int8_prog.desc.serialize_to_string()) f.write(int8_prog.desc.serialize_to_string())
......
...@@ -28,7 +28,10 @@ class QuantOp : public framework::OperatorWithKernel { ...@@ -28,7 +28,10 @@ class QuantOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override{} void InferShape(framework::InferShapeContext* ctx) const override{
ctx->SetOutputDim("Output", ctx->GetInputDim("Input"));
ctx->ShareLoD("Input", /*->*/ "Output");
}
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册