From 6bafa1cec4116813d565746e089abe4cea12bf3f Mon Sep 17 00:00:00 2001
From: LDOUBLEV
Date: Fri, 19 Feb 2021 12:39:43 +0800
Subject: [PATCH] add prune demo

---
 deploy/slim/prune/sensitivity_anal.py | 25 -------------------------
 1 file changed, 25 deletions(-)

diff --git a/deploy/slim/prune/sensitivity_anal.py b/deploy/slim/prune/sensitivity_anal.py
index 6abd9815..bd2b9649 100644
--- a/deploy/slim/prune/sensitivity_anal.py
+++ b/deploy/slim/prune/sensitivity_anal.py
@@ -139,31 +139,6 @@ def main(config, device, logger, vdl_writer):
     program.train(config, train_dataloader, valid_dataloader, device, model,
                   loss_class, optimizer, lr_scheduler, post_process_class,
                   eval_class, pre_best_model_dict, logger, vdl_writer)
-    mode = 'infer'
-    if mode == 'infer':
-        from paddle.jit import to_static
-
-        infer_shape = [3, -1, -1]
-        if config['Architecture']['model_type'] == "rec":
-            infer_shape = [3, 32, -1]  # for rec model, H must be 32
-            if 'Transform' in config['Architecture'] and config['Architecture'][
-                    'Transform'] is not None and config['Architecture'][
-                        'Transform']['name'] == 'TPS':
-                logger.info(
-                    'When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training'
-                )
-                infer_shape[-1] = 100
-        model = to_static(
-            model,
-            input_spec=[
-                paddle.static.InputSpec(
-                    shape=[None] + infer_shape, dtype='float32')
-            ])
-
-        save_path = '{}/inference'.format(config['Global'][
-            'save_inference_dir'])
-        paddle.jit.save(model, save_path)
-        logger.info('inference model is saved to {}'.format(save_path))


 if __name__ == '__main__':
-- 
GitLab
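
Note (not part of the patch): the 25 deleted lines exported the trained model to a static
inference graph with paddle.jit.to_static and paddle.jit.save. The sketch below illustrates
that export step in isolation. It assumes a trained paddle.nn.Layer named `model`, a `logger`,
and the same config keys that appear in the diff (Architecture.model_type,
Global.save_inference_dir); the function name export_inference_model is made up for illustration.

# Illustrative sketch only; `model`, `config` and `logger` are assumed to exist,
# as they do inside sensitivity_anal.py's main().
import paddle
from paddle.jit import to_static


def export_inference_model(model, config, logger):
    # Detection models accept variable H and W; recognition models expect H = 32.
    infer_shape = [3, -1, -1]
    if config['Architecture']['model_type'] == 'rec':
        infer_shape = [3, 32, -1]

    # Trace the dynamic-graph Layer into a static graph. `None` leaves the
    # batch dimension unspecified.
    static_model = to_static(
        model,
        input_spec=[
            paddle.static.InputSpec(
                shape=[None] + infer_shape, dtype='float32')
        ])

    # Writes inference.pdmodel / inference.pdiparams under save_inference_dir.
    save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
    paddle.jit.save(static_model, save_path)
    logger.info('inference model is saved to {}'.format(save_path))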