未验证 提交 8f3753bf 编写于 作者: J Jason 提交者: GitHub

Merge pull request #329 from PaddlePaddle/develop_tmp1

fix scope problem for slim
...@@ -69,7 +69,7 @@ def load_model(model_dir, fixed_input_shape=None): ...@@ -69,7 +69,7 @@ def load_model(model_dir, fixed_input_shape=None):
if status == "Prune": if status == "Prune":
from .slim.prune import update_program from .slim.prune import update_program
model.test_prog = update_program(model.test_prog, model_dir, model.test_prog = update_program(model.test_prog, model_dir,
model.places[0]) model.places[0], scope=model_scope)
import pickle import pickle
with open(osp.join(model_dir, 'model.pdparams'), 'rb') as f: with open(osp.join(model_dir, 'model.pdparams'), 'rb') as f:
load_dict = pickle.load(f) load_dict = pickle.load(f)
......
...@@ -104,7 +104,7 @@ def sensitivity(program, ...@@ -104,7 +104,7 @@ def sensitivity(program,
return sensitivities return sensitivities
def channel_prune(program, prune_names, prune_ratios, place, only_graph=False): def channel_prune(program, prune_names, prune_ratios, place, only_graph=False, scope=None):
"""通道裁剪。 """通道裁剪。
Args: Args:
...@@ -134,7 +134,8 @@ def channel_prune(program, prune_names, prune_ratios, place, only_graph=False): ...@@ -134,7 +134,8 @@ def channel_prune(program, prune_names, prune_ratios, place, only_graph=False):
pruned_num = int(round(origin_num * (ratio))) pruned_num = int(round(origin_num * (ratio)))
prune_ratios[index] = ratio prune_ratios[index] = ratio
index += 1 index += 1
scope = fluid.global_scope() if scope is None:
scope = fluid.global_scope()
pruner = Pruner() pruner = Pruner()
program, _, _ = pruner.prune( program, _, _ = pruner.prune(
program, program,
...@@ -175,12 +176,12 @@ def prune_program(model, prune_params_ratios=None): ...@@ -175,12 +176,12 @@ def prune_program(model, prune_params_ratios=None):
prune_params_ratios[prune_name] for prune_name in prune_names prune_params_ratios[prune_name] for prune_name in prune_names
] ]
model.train_prog = channel_prune(train_prog, prune_names, prune_ratios, model.train_prog = channel_prune(train_prog, prune_names, prune_ratios,
place) place, scope=model.scope)
model.test_prog = channel_prune( model.test_prog = channel_prune(
eval_prog, prune_names, prune_ratios, place, only_graph=True) eval_prog, prune_names, prune_ratios, place, only_graph=True, scope=model.scope)
def update_program(program, model_dir, place): def update_program(program, model_dir, place, scope=None):
"""根据裁剪信息更新Program和参数。 """根据裁剪信息更新Program和参数。
Args: Args:
...@@ -197,10 +198,12 @@ def update_program(program, model_dir, place): ...@@ -197,10 +198,12 @@ def update_program(program, model_dir, place):
shapes = yaml.load(f.read(), Loader=yaml.Loader) shapes = yaml.load(f.read(), Loader=yaml.Loader)
for param, shape in shapes.items(): for param, shape in shapes.items():
graph.var(param).set_shape(shape) graph.var(param).set_shape(shape)
if scope is None:
scope = fluid.global_scope()
for block in program.blocks: for block in program.blocks:
for param in block.all_parameters(): for param in block.all_parameters():
if param.name in shapes: if param.name in shapes:
param_tensor = fluid.global_scope().find_var( param_tensor = scope.find_var(
param.name).get_tensor() param.name).get_tensor()
param_tensor.set( param_tensor.set(
np.zeros(list(shapes[param.name])).astype('float32'), np.zeros(list(shapes[param.name])).astype('float32'),
...@@ -293,7 +296,7 @@ def get_params_ratios(sensitivities_file, eval_metric_loss=0.05): ...@@ -293,7 +296,7 @@ def get_params_ratios(sensitivities_file, eval_metric_loss=0.05):
return params_ratios return params_ratios
def cal_model_size(program, place, sensitivities_file, eval_metric_loss=0.05): def cal_model_size(program, place, sensitivities_file, eval_metric_loss=0.05, scope=None):
"""在可容忍的精度损失下,计算裁剪后模型大小相对于当前模型大小的比例。 """在可容忍的精度损失下,计算裁剪后模型大小相对于当前模型大小的比例。
Args: Args:
...@@ -326,7 +329,8 @@ def cal_model_size(program, place, sensitivities_file, eval_metric_loss=0.05): ...@@ -326,7 +329,8 @@ def cal_model_size(program, place, sensitivities_file, eval_metric_loss=0.05):
list(prune_params_ratios.keys()), list(prune_params_ratios.keys()),
list(prune_params_ratios.values()), list(prune_params_ratios.values()),
place, place,
only_graph=True) only_graph=True,
scope=scope)
origin_size = 0 origin_size = 0
new_size = 0 new_size = 0
for var in program.list_vars(): for var in program.list_vars():
......
...@@ -171,10 +171,14 @@ def get_prune_params(model): ...@@ -171,10 +171,14 @@ def get_prune_params(model):
model_type.startswith('ShuffleNetV2'): model_type.startswith('ShuffleNetV2'):
for block in program.blocks: for block in program.blocks:
for param in block.all_parameters(): for param in block.all_parameters():
pd_var = fluid.global_scope().find_var(param.name) pd_var = model.scope.find_var(param.name)
pd_param = pd_var.get_tensor() try:
if len(np.array(pd_param).shape) == 4: pd_param = pd_var.get_tensor()
prune_names.append(param.name) if len(np.array(pd_param).shape) == 4:
prune_names.append(param.name)
except Exception as e:
print("None Tensor Name: ", param.name)
print("Error message: {}".format(e))
if model_type == 'AlexNet': if model_type == 'AlexNet':
prune_names.remove('conv5_weights') prune_names.remove('conv5_weights')
if model_type == 'ShuffleNetV2': if model_type == 'ShuffleNetV2':
...@@ -285,11 +289,35 @@ def get_prune_params(model): ...@@ -285,11 +289,35 @@ def get_prune_params(model):
prune_names.remove(i) prune_names.remove(i)
elif model_type.startswith('DeepLabv3p'): elif model_type.startswith('DeepLabv3p'):
if model_type.lower() == "deeplabv3p_mobilenetv3_large_x1_0_ssld":
params_not_prune = [
'last_1x1_conv_weights', 'conv14_se_2_weights',
'conv16_depthwise_weights', 'conv13_depthwise_weights',
'conv15_se_2_weights', 'conv2_depthwise_weights',
'conv6_depthwise_weights', 'conv8_depthwise_weights',
'fc_weights', 'conv3_depthwise_weights', 'conv7_se_2_weights',
'conv16_expand_weights', 'conv16_se_2_weights',
'conv10_depthwise_weights', 'conv11_depthwise_weights',
'conv15_expand_weights', 'conv5_expand_weights',
'conv15_depthwise_weights', 'conv14_depthwise_weights',
'conv12_se_2_weights', 'conv1_weights',
'conv13_expand_weights', 'conv_last_weights',
'conv12_depthwise_weights', 'conv13_se_2_weights',
'conv12_expand_weights', 'conv5_depthwise_weights',
'conv6_se_2_weights', 'conv10_expand_weights',
'conv9_depthwise_weights', 'conv6_expand_weights',
'conv5_se_2_weights', 'conv14_expand_weights',
'conv4_depthwise_weights', 'conv7_expand_weights',
'conv7_depthwise_weights'
]
for param in program.global_block().all_parameters(): for param in program.global_block().all_parameters():
if 'weight' not in param.name: if 'weight' not in param.name:
continue continue
if 'dwise' in param.name or 'depthwise' in param.name or 'logit' in param.name: if 'dwise' in param.name or 'depthwise' in param.name or 'logit' in param.name:
continue continue
if model_type.lower() == "deeplabv3p_mobilenetv3_large_x1_0_ssld":
if param.name in params_not_prune:
continue
prune_names.append(param.name) prune_names.append(param.name)
params_not_prune = [ params_not_prune = [
'xception_{}/exit_flow/block2/separable_conv3/pointwise/weights'. 'xception_{}/exit_flow/block2/separable_conv3/pointwise/weights'.
......
...@@ -42,7 +42,7 @@ def visualize(model, sensitivities_file, save_dir='./'): ...@@ -42,7 +42,7 @@ def visualize(model, sensitivities_file, save_dir='./'):
y = list() y = list()
for loss_thresh in tqdm.tqdm(list(np.arange(0.05, 1, 0.05))): for loss_thresh in tqdm.tqdm(list(np.arange(0.05, 1, 0.05))):
prune_ratio = 1 - cal_model_size( prune_ratio = 1 - cal_model_size(
program, place, sensitivities_file, eval_metric_loss=loss_thresh) program, place, sensitivities_file, eval_metric_loss=loss_thresh, scope=model.scope)
x.append(prune_ratio) x.append(prune_ratio)
y.append(loss_thresh) y.append(loss_thresh)
plt.plot(x, y, color='green', linewidth=0.5, marker='o', markersize=3) plt.plot(x, y, color='green', linewidth=0.5, marker='o', markersize=3)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册