Converting a quantization-trained model with specified ops removed
Created by: wanghaoshuang
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization.quantization_pass import (
    QuantizationTransformPass, QuantizationFreezePass, ConvertToInt8Pass)
# PyramidBox network definition from the local test module.
from pyramidbox_test import PyramidBox
if __name__ == '__main__':
    float_model_path_src = 'faceboxes_V016_float32_compress_0327_float'
    place = fluid.CPUPlace()
    exe = fluid.Executor(place)

    # step1: define the inference network without the assign op
    main_program = fluid.Program()
    startup_program = fluid.Program()
    with fluid.program_guard(main_program, startup_program):
        print('debug: construct network')
        network = PyramidBox(
            data_shape=[3, 240, 320], sub_network=True, is_infer=True)
        inference_program, nmsed_out = network.infer(main_program)
        print('nmsed_out name: {}'.format(nmsed_out.name))

    # step2: initialize variables
    exe.run(startup_program)
    # step3: insert quantization operators into the inference network
    transform_pass = QuantizationTransformPass(
        scope=fluid.executor.global_scope(),
        place=place,
        weight_bits=8,
        activation_bits=8,
        activation_quantize_type='range_abs_max')
    eval_graph = IrGraph(core.Graph(inference_program.desc), for_test=True)
    transform_pass.apply(eval_graph)
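    # Note: after this pass the graph contains fake_quantize_*/fake_dequantize_*
    # op pairs around the quantizable ops; the exact op names depend on the
    # Paddle version and the chosen activation_quantize_type.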
    # step4: load float weights into the global scope
    fluid.io.load_inference_model(
        float_model_path_src,
        exe,
        model_filename='model',
        params_filename='weights')
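    # load_inference_model also returns a program, but it is called here only
    # for its side effect: filling the global scope with the trained float
    # parameters, overwriting the random values initialized in step2.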
    # step5: convert to an int8 model
    freeze_pass = QuantizationFreezePass(
        scope=fluid.executor.global_scope(),
        place=place,
        weight_bits=8,
        activation_bits=8)
    freeze_pass.apply(eval_graph)
    convert_int8_pass = ConvertToInt8Pass(
        scope=fluid.executor.global_scope(), place=place)
    convert_int8_pass.apply(eval_graph)
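    # At this point the freeze pass has rewritten the fake quant/dequant ops
    # using the recorded scales, and ConvertToInt8Pass has cast the weight
    # tensors to int8 so the saved model is smaller.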
    # dump the final graph for visual inspection (rendering needs graphviz)
    eval_graph.draw('.', 'eval')
    # step6: save the int8 model to the filesystem
    program = eval_graph.to_program()
    # fetch the NMS output by name; it should match nmsed_out.name printed above
    out_vars = [program.global_block().var('detection_output_0.tmp_0')]
    fluid.io.save_inference_model(
        "./output_int8",
        ['image'],
        out_vars,
        exe,
        main_program=program,
        model_filename='model',
        params_filename='weights',
        export_for_deployment=True)
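    # This writes ./output_int8/model (the program description) and
    # ./output_int8/weights (all persistable parameters in a single file).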
    # step7: load the saved model back and check it
    program, feed_target_names, fetch_targets = fluid.io.load_inference_model(
        "./output_int8",
        exe,
        model_filename='model',
        params_filename='weights')
    print("feed_target_names: {}".format(feed_target_names))
    print("fetch_targets: {}".format(fetch_targets))
    for op in program.global_block().ops:
        print("op: {}".format(op.type))
    for var in program.list_vars():
        if var.persistable:
            print("var: {}".format(var.name))