diff --git a/paddlex/cv/models/slim/prune.py b/paddlex/cv/models/slim/prune.py index 810679d3d7cf70a14922a594af3468294f12d29c..08f8603fff612ca0936f7b17f9d4206f14ca7654 100644 --- a/paddlex/cv/models/slim/prune.py +++ b/paddlex/cv/models/slim/prune.py @@ -116,6 +116,21 @@ def channel_prune(program, prune_names, prune_ratios, place, only_graph=False): Returns: paddle.fluid.Program: 裁剪后的Program。 """ + prog_var_shape_dict = {} + for var in program.list_vars(): + try: + prog_var_shape_dict[var.name] = var.shape + except Exception: + pass + index = 0 + for param, ratio in zip(prune_names, prune_ratios): + origin_num = prog_var_shape_dict[param][0] + pruned_num = int(round(origin_num * ratio)) + while origin_num == pruned_num: + ratio -= 0.1 + pruned_num = int(round(origin_num * (ratio))) + prune_ratios[index] = ratio + index += 1 scope = fluid.global_scope() pruner = Pruner() program, _, _ = pruner.prune( @@ -266,6 +281,7 @@ def get_params_ratios(sensitivities_file, eval_metric_loss=0.05): sensitivitives = paddleslim.prune.load_sensitivities(sensitivities_file) params_ratios = paddleslim.prune.get_ratios_by_loss( sensitivitives, eval_metric_loss) + return params_ratios @@ -284,6 +300,19 @@ def cal_model_size(program, place, sensitivities_file, eval_metric_loss=0.05): """ prune_params_ratios = get_params_ratios(sensitivities_file, eval_metric_loss) + prog_var_shape_dict = {} + for var in program.list_vars(): + try: + prog_var_shape_dict[var.name] = var.shape + except Exception: + pass + for param, ratio in prune_params_ratios.items(): + origin_num = prog_var_shape_dict[param][0] + pruned_num = int(round(origin_num * ratio)) + while origin_num == pruned_num: + ratio -= 0.1 + pruned_num = int(round(origin_num * (ratio))) + prune_params_ratios[param] = ratio prune_program = channel_prune( program, list(prune_params_ratios.keys()),