提交 76f40469 编写于 作者: L LDOUBLEV

revert prune

上级 b79fee11
...@@ -24,14 +24,6 @@ sys.path.append(__dir__) ...@@ -24,14 +24,6 @@ sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..', '..', '..')) sys.path.append(os.path.join(__dir__, '..', '..', '..'))
sys.path.append(os.path.join(__dir__, '..', '..', '..', 'tools')) sys.path.append(os.path.join(__dir__, '..', '..', '..', 'tools'))
import json
import cv2
import paddle
from paddle import fluid
import paddleslim as slim
from copy import deepcopy
from tools import program
import paddle import paddle
import paddle.distributed as dist import paddle.distributed as dist
from ppocr.data import build_dataloader from ppocr.data import build_dataloader
...@@ -46,27 +38,13 @@ import tools.program as program ...@@ -46,27 +38,13 @@ import tools.program as program
dist.get_world_size() dist.get_world_size()
def get_pruned_params(parameters, mode="det"): def get_pruned_params(parameters):
if mode == "det":
skip_prune_params = [
"conv2d_56.w_0", "conv2d_54.w_0", "conv2d_51.w_0",
"conv_last_weights", "conv14_linear_weights",
"conv13_expand_weights", "conv12_linear_weights",
"conv12_expand_weights", "conv7_expand_weights",
"conv8_expand_weights", "conv8_linear_weights",
"conv5_linear_weights", "conv5_expand_weights",
"conv3_linear_weights"
]
skip_prune_params = skip_prune_params + ['conv2d_53.w_0']
else:
skip_prune_params = None
params = [] params = []
for param in parameters: for param in parameters:
if len( if len(
param.shape param.shape
) == 4 and 'depthwise' not in param.name and 'transpose' not in param.name and "conv2d_57" not in param.name and "conv2d_56" not in param.name: ) == 4 and 'depthwise' not in param.name and 'transpose' not in param.name and "conv2d_57" not in param.name and "conv2d_56" not in param.name:
if param.name not in skip_prune_params:
params.append(param.name) params.append(param.name)
return params return params
...@@ -118,6 +96,11 @@ def main(config, device, logger, vdl_writer): ...@@ -118,6 +96,11 @@ def main(config, device, logger, vdl_writer):
# load pretrain model # load pretrain model
pre_best_model_dict = init_model(config, model, logger, optimizer) pre_best_model_dict = init_model(config, model, logger, optimizer)
logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
format(len(train_dataloader), len(valid_dataloader)))
# build metric
eval_class = build_metric(config['Metric'])
logger.info('train dataloader has {} iters, valid dataloader has {} iters'. logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
format(len(train_dataloader), len(valid_dataloader))) format(len(train_dataloader), len(valid_dataloader)))
...@@ -127,20 +110,22 @@ def main(config, device, logger, vdl_writer): ...@@ -127,20 +110,22 @@ def main(config, device, logger, vdl_writer):
logger.info(f"metric['hmean']: {metric['hmean']}") logger.info(f"metric['hmean']: {metric['hmean']}")
return metric['hmean'] return metric['hmean']
pruner.sensitive( params_sensitive = pruner.sensitive(
eval_func=eval_fn, eval_func=eval_fn,
sen_file="./sen.pickle", sen_file="./sen.pickle",
skip_vars=[ skip_vars=[
"conv2d_57.w_0", "conv2d_transpose_2.w_0", "conv2d_transpose_3.w_0" "conv2d_57.w_0", "conv2d_transpose_2.w_0", "conv2d_transpose_3.w_0"
]) ])
params = get_pruned_params(model.parameters()) logger.info(
ratios = {} "The sensitivity analysis results of model parameters saved in sen.pickle"
# set the prune ratio is 0.2 )
for param in params: # calculate pruned params's ratio
ratios[param] = 0.2 params_sensitive = pruner._get_ratios_by_loss(params_sensitive, loss=0.02)
for key in params_sensitive.keys():
logger.info(f"{key}, {params_sensitive[key]}")
plan = pruner.prune_vars(ratios, [0]) plan = pruner.prune_vars(params_sensitive, [0])
for param in model.parameters(): for param in model.parameters():
if ("weights" in param.name and "conv" in param.name) or ( if ("weights" in param.name and "conv" in param.name) or (
"w_0" in param.name and "conv2d" in param.name): "w_0" in param.name and "conv2d" in param.name):
...@@ -150,6 +135,7 @@ def main(config, device, logger, vdl_writer): ...@@ -150,6 +135,7 @@ def main(config, device, logger, vdl_writer):
logger.info(f"FLOPs after pruning: {flops}") logger.info(f"FLOPs after pruning: {flops}")
# start train # start train
program.train(config, train_dataloader, valid_dataloader, device, model, program.train(config, train_dataloader, valid_dataloader, device, model,
loss_class, optimizer, lr_scheduler, post_process_class, loss_class, optimizer, lr_scheduler, post_process_class,
eval_class, pre_best_model_dict, logger, vdl_writer) eval_class, pre_best_model_dict, logger, vdl_writer)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册