From 76f404690e921e59ec2a150c82c9bb4d032da87d Mon Sep 17 00:00:00 2001 From: LDOUBLEV Date: Wed, 12 May 2021 10:29:11 +0800 Subject: [PATCH] revert prune --- deploy/slim/prune/sensitivity_anal.py | 48 ++++++++++----------------- 1 file changed, 17 insertions(+), 31 deletions(-) diff --git a/deploy/slim/prune/sensitivity_anal.py b/deploy/slim/prune/sensitivity_anal.py index ecf990b1..bd2b9649 100644 --- a/deploy/slim/prune/sensitivity_anal.py +++ b/deploy/slim/prune/sensitivity_anal.py @@ -24,14 +24,6 @@ sys.path.append(__dir__) sys.path.append(os.path.join(__dir__, '..', '..', '..')) sys.path.append(os.path.join(__dir__, '..', '..', '..', 'tools')) -import json -import cv2 -import paddle -from paddle import fluid -import paddleslim as slim -from copy import deepcopy -from tools import program - import paddle import paddle.distributed as dist from ppocr.data import build_dataloader @@ -46,28 +38,14 @@ import tools.program as program dist.get_world_size() -def get_pruned_params(parameters, mode="det"): - if mode == "det": - skip_prune_params = [ - "conv2d_56.w_0", "conv2d_54.w_0", "conv2d_51.w_0", - "conv_last_weights", "conv14_linear_weights", - "conv13_expand_weights", "conv12_linear_weights", - "conv12_expand_weights", "conv7_expand_weights", - "conv8_expand_weights", "conv8_linear_weights", - "conv5_linear_weights", "conv5_expand_weights", - "conv3_linear_weights" - ] - skip_prune_params = skip_prune_params + ['conv2d_53.w_0'] - else: - skip_prune_params = None +def get_pruned_params(parameters): params = [] for param in parameters: if len( param.shape ) == 4 and 'depthwise' not in param.name and 'transpose' not in param.name and "conv2d_57" not in param.name and "conv2d_56" not in param.name: - if param.name not in skip_prune_params: - params.append(param.name) + params.append(param.name) return params @@ -118,6 +96,11 @@ def main(config, device, logger, vdl_writer): # load pretrain model pre_best_model_dict = init_model(config, model, logger, optimizer) + logger.info('train dataloader has {} iters, valid dataloader has {} iters'. + format(len(train_dataloader), len(valid_dataloader))) + # build metric + eval_class = build_metric(config['Metric']) + logger.info('train dataloader has {} iters, valid dataloader has {} iters'. format(len(train_dataloader), len(valid_dataloader))) @@ -127,20 +110,22 @@ def main(config, device, logger, vdl_writer): logger.info(f"metric['hmean']: {metric['hmean']}") return metric['hmean'] - pruner.sensitive( + params_sensitive = pruner.sensitive( eval_func=eval_fn, sen_file="./sen.pickle", skip_vars=[ "conv2d_57.w_0", "conv2d_transpose_2.w_0", "conv2d_transpose_3.w_0" ]) - params = get_pruned_params(model.parameters()) - ratios = {} - # set the prune ratio is 0.2 - for param in params: - ratios[param] = 0.2 + logger.info( + "The sensitivity analysis results of model parameters saved in sen.pickle" + ) + # calculate pruned params's ratio + params_sensitive = pruner._get_ratios_by_loss(params_sensitive, loss=0.02) + for key in params_sensitive.keys(): + logger.info(f"{key}, {params_sensitive[key]}") - plan = pruner.prune_vars(ratios, [0]) + plan = pruner.prune_vars(params_sensitive, [0]) for param in model.parameters(): if ("weights" in param.name and "conv" in param.name) or ( "w_0" in param.name and "conv2d" in param.name): @@ -150,6 +135,7 @@ def main(config, device, logger, vdl_writer): logger.info(f"FLOPs after pruning: {flops}") # start train + program.train(config, train_dataloader, valid_dataloader, device, model, loss_class, optimizer, lr_scheduler, post_process_class, eval_class, pre_best_model_dict, logger, vdl_writer) -- GitLab