提交 ee3df4c2 编写于 作者: littletomatodonkey's avatar littletomatodonkey

add support for quant distillation

上级 7608f163
...@@ -37,6 +37,17 @@ from paddleslim.dygraph.quant import QAT ...@@ -37,6 +37,17 @@ from paddleslim.dygraph.quant import QAT
from ppocr.data import build_dataloader from ppocr.data import build_dataloader
def export_single_model(quanter, model, infer_shape, save_path, logger):
quanter.save_quantized_model(
model,
save_path,
input_spec=[
paddle.static.InputSpec(
shape=[None] + infer_shape, dtype='float32')
])
logger.info('inference QAT model is saved to {}'.format(save_path))
def main(): def main():
############################################################################################################ ############################################################################################################
# 1. quantization configs # 1. quantization configs
...@@ -76,7 +87,14 @@ def main(): ...@@ -76,7 +87,14 @@ def main():
# for rec algorithm # for rec algorithm
if hasattr(post_process_class, 'character'): if hasattr(post_process_class, 'character'):
char_num = len(getattr(post_process_class, 'character')) char_num = len(getattr(post_process_class, 'character'))
if config['Architecture']["algorithm"] in ["Distillation",
]: # distillation model
for key in config['Architecture']["Models"]:
config['Architecture']["Models"][key]["Head"][
'out_channels'] = char_num
else: # base rec model
config['Architecture']["Head"]['out_channels'] = char_num config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
# get QAT model # get QAT model
...@@ -93,24 +111,27 @@ def main(): ...@@ -93,24 +111,27 @@ def main():
valid_dataloader = build_dataloader(config, 'Eval', device, logger) valid_dataloader = build_dataloader(config, 'Eval', device, logger)
# start eval # start eval
metirc = program.eval(model, valid_dataloader, post_process_class, model_type = config['Architecture']['model_type']
eval_class) metric = program.eval(model, valid_dataloader, post_process_class,
eval_class, model_type)
logger.info('metric eval ***************') logger.info('metric eval ***************')
for k, v in metirc.items(): for k, v in metric.items():
logger.info('{}:{}'.format(k, v)) logger.info('{}:{}'.format(k, v))
save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
infer_shape = [3, 32, 100] if config['Architecture'][ infer_shape = [3, 32, 100] if config['Architecture'][
'model_type'] != "det" else [3, 640, 640] 'model_type'] != "det" else [3, 640, 640]
quanter.save_quantized_model( save_path = config["Global"]["save_inference_dir"]
model,
save_path, arch_config = config["Architecture"]
input_spec=[ if arch_config["algorithm"] in ["Distillation", ]: # distillation model
paddle.static.InputSpec( for idx, name in enumerate(model.model_name_list):
shape=[None] + infer_shape, dtype='float32') sub_model_save_path = os.path.join(save_path, name, "inference")
]) export_single_model(quanter, model.model_list[idx], infer_shape,
logger.info('inference QAT model is saved to {}'.format(save_path)) sub_model_save_path, logger)
else:
save_path = os.path.join(save_path, "inference")
export_single_model(quanter, model, infer_shape, save_path, logger)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -109,9 +109,18 @@ def main(config, device, logger, vdl_writer): ...@@ -109,9 +109,18 @@ def main(config, device, logger, vdl_writer):
# for rec algorithm # for rec algorithm
if hasattr(post_process_class, 'character'): if hasattr(post_process_class, 'character'):
char_num = len(getattr(post_process_class, 'character')) char_num = len(getattr(post_process_class, 'character'))
if config['Architecture']["algorithm"] in ["Distillation",
]: # distillation model
for key in config['Architecture']["Models"]:
config['Architecture']["Models"][key]["Head"][
'out_channels'] = char_num
else: # base rec model
config['Architecture']["Head"]['out_channels'] = char_num config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
quanter = QAT(config=quant_config, act_preprocess=PACT)
quanter.quantize(model)
if config['Global']['distributed']: if config['Global']['distributed']:
model = paddle.DataParallel(model) model = paddle.DataParallel(model)
...@@ -132,8 +141,6 @@ def main(config, device, logger, vdl_writer): ...@@ -132,8 +141,6 @@ def main(config, device, logger, vdl_writer):
logger.info('train dataloader has {} iters, valid dataloader has {} iters'. logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
format(len(train_dataloader), len(valid_dataloader))) format(len(train_dataloader), len(valid_dataloader)))
quanter = QAT(config=quant_config, act_preprocess=PACT)
quanter.quantize(model)
# start train # start train
program.train(config, train_dataloader, valid_dataloader, device, model, program.train(config, train_dataloader, valid_dataloader, device, model,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册