提交 6bafa1ce 编写于 作者: L LDOUBLEV

add prune demo

上级 e0851f2b
......@@ -139,31 +139,6 @@ def main(config, device, logger, vdl_writer):
program.train(config, train_dataloader, valid_dataloader, device, model,
loss_class, optimizer, lr_scheduler, post_process_class,
eval_class, pre_best_model_dict, logger, vdl_writer)
mode = 'infer'
if mode == 'infer':
from paddle.jit import to_static
infer_shape = [3, -1, -1]
if config['Architecture']['model_type'] == "rec":
infer_shape = [3, 32, -1] # for rec model, H must be 32
if 'Transform' in config['Architecture'] and config['Architecture'][
'Transform'] is not None and config['Architecture'][
'Transform']['name'] == 'TPS':
logger.info(
'When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training'
)
infer_shape[-1] = 100
model = to_static(
model,
input_spec=[
paddle.static.InputSpec(
shape=[None] + infer_shape, dtype='float32')
])
save_path = '{}/inference'.format(config['Global'][
'save_inference_dir'])
paddle.jit.save(model, save_path)
logger.info('inference model is saved to {}'.format(save_path))
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册