From 6a89ed138084ffa6374493a23de59ca2accf2cfa Mon Sep 17 00:00:00 2001
From: jiangjiajun
Date: Tue, 22 Sep 2020 15:03:34 +0000
Subject: [PATCH] fix scope problem for slim

---
 paddlex/cv/models/slim/prune.py        | 20 ++++++++------
 paddlex/cv/models/slim/prune_config.py | 36 +++++++++++++++++++++++---
 paddlex/cv/models/slim/visualize.py    |  2 +-
 3 files changed, 45 insertions(+), 13 deletions(-)

diff --git a/paddlex/cv/models/slim/prune.py b/paddlex/cv/models/slim/prune.py
index 4ff3e23..f2dc8da 100644
--- a/paddlex/cv/models/slim/prune.py
+++ b/paddlex/cv/models/slim/prune.py
@@ -104,7 +104,7 @@ def sensitivity(program,
     return sensitivities
 
 
-def channel_prune(program, prune_names, prune_ratios, place, only_graph=False):
+def channel_prune(program, prune_names, prune_ratios, place, only_graph=False, scope=None):
     """通道裁剪。
 
     Args:
@@ -134,7 +134,8 @@ def channel_prune(program, prune_names, prune_ratios, place, only_graph=False):
             pruned_num = int(round(origin_num * (ratio)))
             prune_ratios[index] = ratio
         index += 1
-    scope = fluid.global_scope()
+    if scope is None:
+        scope = fluid.global_scope()
     pruner = Pruner()
     program, _, _ = pruner.prune(
         program,
@@ -175,12 +176,12 @@ def prune_program(model, prune_params_ratios=None):
         prune_params_ratios[prune_name] for prune_name in prune_names
     ]
     model.train_prog = channel_prune(train_prog, prune_names, prune_ratios,
-                                     place)
+                                     place, scope=model.scope)
     model.test_prog = channel_prune(
-        eval_prog, prune_names, prune_ratios, place, only_graph=True)
+        eval_prog, prune_names, prune_ratios, place, only_graph=True, scope=model.scope)
 
 
-def update_program(program, model_dir, place):
+def update_program(program, model_dir, place, scope=None):
     """根据裁剪信息更新Program和参数。
 
     Args:
@@ -197,10 +198,12 @@ def update_program(program, model_dir, place):
         shapes = yaml.load(f.read(), Loader=yaml.Loader)
         for param, shape in shapes.items():
             graph.var(param).set_shape(shape)
+    if scope is None:
+        scope = fluid.global_scope()
     for block in program.blocks:
         for param in block.all_parameters():
             if param.name in shapes:
-                param_tensor = fluid.global_scope().find_var(
+                param_tensor = scope.find_var(
                     param.name).get_tensor()
                 param_tensor.set(
                     np.zeros(list(shapes[param.name])).astype('float32'),
@@ -293,7 +296,7 @@ def get_params_ratios(sensitivities_file, eval_metric_loss=0.05):
     return params_ratios
 
 
-def cal_model_size(program, place, sensitivities_file, eval_metric_loss=0.05):
+def cal_model_size(program, place, sensitivities_file, eval_metric_loss=0.05, scope=None):
     """在可容忍的精度损失下，计算裁剪后模型大小相对于当前模型大小的比例。
 
     Args:
@@ -326,7 +329,8 @@ def cal_model_size(program, place, sensitivities_file, eval_metric_loss=0.05):
         list(prune_params_ratios.keys()),
         list(prune_params_ratios.values()),
         place,
-        only_graph=True)
+        only_graph=True,
+        scope=scope)
     origin_size = 0
     new_size = 0
     for var in program.list_vars():
diff --git a/paddlex/cv/models/slim/prune_config.py b/paddlex/cv/models/slim/prune_config.py
index d85867e..6bfe353 100644
--- a/paddlex/cv/models/slim/prune_config.py
+++ b/paddlex/cv/models/slim/prune_config.py
@@ -171,10 +171,14 @@ def get_prune_params(model):
             model_type.startswith('ShuffleNetV2'):
         for block in program.blocks:
             for param in block.all_parameters():
-                pd_var = fluid.global_scope().find_var(param.name)
-                pd_param = pd_var.get_tensor()
-                if len(np.array(pd_param).shape) == 4:
-                    prune_names.append(param.name)
+                pd_var = model.scope.find_var(param.name)
+                try:
+                    pd_param = pd_var.get_tensor()
+                    if len(np.array(pd_param).shape) == 4:
+                        prune_names.append(param.name)
+                except Exception as e:
+                    print("None Tensor Name: ", param.name)
+                    print("Error message: {}".format(e))
         if model_type == 'AlexNet':
             prune_names.remove('conv5_weights')
         if model_type == 'ShuffleNetV2':
@@ -285,11 +289,35 @@ def get_prune_params(model):
             prune_names.remove(i)
 
     elif model_type.startswith('DeepLabv3p'):
+        if model_type.lower() == "deeplabv3p_mobilenetv3_large_x1_0_ssld":
+            params_not_prune = [
+                'last_1x1_conv_weights', 'conv14_se_2_weights',
+                'conv16_depthwise_weights', 'conv13_depthwise_weights',
+                'conv15_se_2_weights', 'conv2_depthwise_weights',
+                'conv6_depthwise_weights', 'conv8_depthwise_weights',
+                'fc_weights', 'conv3_depthwise_weights', 'conv7_se_2_weights',
+                'conv16_expand_weights', 'conv16_se_2_weights',
+                'conv10_depthwise_weights', 'conv11_depthwise_weights',
+                'conv15_expand_weights', 'conv5_expand_weights',
+                'conv15_depthwise_weights', 'conv14_depthwise_weights',
+                'conv12_se_2_weights', 'conv1_weights',
+                'conv13_expand_weights', 'conv_last_weights',
+                'conv12_depthwise_weights', 'conv13_se_2_weights',
+                'conv12_expand_weights', 'conv5_depthwise_weights',
+                'conv6_se_2_weights', 'conv10_expand_weights',
+                'conv9_depthwise_weights', 'conv6_expand_weights',
+                'conv5_se_2_weights', 'conv14_expand_weights',
+                'conv4_depthwise_weights', 'conv7_expand_weights',
+                'conv7_depthwise_weights'
+            ]
         for param in program.global_block().all_parameters():
             if 'weight' not in param.name:
                 continue
             if 'dwise' in param.name or 'depthwise' in param.name or 'logit' in param.name:
                 continue
+            if model_type.lower() == "deeplabv3p_mobilenetv3_large_x1_0_ssld":
+                if param.name in params_not_prune:
+                    continue
             prune_names.append(param.name)
         params_not_prune = [
             'xception_{}/exit_flow/block2/separable_conv3/pointwise/weights'.
diff --git a/paddlex/cv/models/slim/visualize.py b/paddlex/cv/models/slim/visualize.py
index 4be6721..a8cb444 100644
--- a/paddlex/cv/models/slim/visualize.py
+++ b/paddlex/cv/models/slim/visualize.py
@@ -42,7 +42,7 @@ def visualize(model, sensitivities_file, save_dir='./'):
     y = list()
     for loss_thresh in tqdm.tqdm(list(np.arange(0.05, 1, 0.05))):
         prune_ratio = 1 - cal_model_size(
-            program, place, sensitivities_file, eval_metric_loss=loss_thresh)
+            program, place, sensitivities_file, eval_metric_loss=loss_thresh, scope=model.scope)
         x.append(prune_ratio)
         y.append(loss_thresh)
     plt.plot(x, y, color='green', linewidth=0.5, marker='o', markersize=3)
--
GitLab