提交 76f40469 编写于 作者: L LDOUBLEV

revert prune

上级 b79fee11
......@@ -24,14 +24,6 @@ sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..', '..', '..'))
sys.path.append(os.path.join(__dir__, '..', '..', '..', 'tools'))
import json
import cv2
import paddle
from paddle import fluid
import paddleslim as slim
from copy import deepcopy
from tools import program
import paddle
import paddle.distributed as dist
from ppocr.data import build_dataloader
......@@ -46,27 +38,13 @@ import tools.program as program
dist.get_world_size()
def get_pruned_params(parameters, mode="det"):
if mode == "det":
skip_prune_params = [
"conv2d_56.w_0", "conv2d_54.w_0", "conv2d_51.w_0",
"conv_last_weights", "conv14_linear_weights",
"conv13_expand_weights", "conv12_linear_weights",
"conv12_expand_weights", "conv7_expand_weights",
"conv8_expand_weights", "conv8_linear_weights",
"conv5_linear_weights", "conv5_expand_weights",
"conv3_linear_weights"
]
skip_prune_params = skip_prune_params + ['conv2d_53.w_0']
else:
skip_prune_params = None
def get_pruned_params(parameters):
params = []
for param in parameters:
if len(
param.shape
) == 4 and 'depthwise' not in param.name and 'transpose' not in param.name and "conv2d_57" not in param.name and "conv2d_56" not in param.name:
if param.name not in skip_prune_params:
params.append(param.name)
return params
......@@ -118,6 +96,11 @@ def main(config, device, logger, vdl_writer):
# load pretrain model
pre_best_model_dict = init_model(config, model, logger, optimizer)
logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
format(len(train_dataloader), len(valid_dataloader)))
# build metric
eval_class = build_metric(config['Metric'])
logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
format(len(train_dataloader), len(valid_dataloader)))
......@@ -127,20 +110,22 @@ def main(config, device, logger, vdl_writer):
logger.info(f"metric['hmean']: {metric['hmean']}")
return metric['hmean']
pruner.sensitive(
params_sensitive = pruner.sensitive(
eval_func=eval_fn,
sen_file="./sen.pickle",
skip_vars=[
"conv2d_57.w_0", "conv2d_transpose_2.w_0", "conv2d_transpose_3.w_0"
])
params = get_pruned_params(model.parameters())
ratios = {}
# set the prune ratio is 0.2
for param in params:
ratios[param] = 0.2
logger.info(
"The sensitivity analysis results of model parameters saved in sen.pickle"
)
# calculate pruned params's ratio
params_sensitive = pruner._get_ratios_by_loss(params_sensitive, loss=0.02)
for key in params_sensitive.keys():
logger.info(f"{key}, {params_sensitive[key]}")
plan = pruner.prune_vars(ratios, [0])
plan = pruner.prune_vars(params_sensitive, [0])
for param in model.parameters():
if ("weights" in param.name and "conv" in param.name) or (
"w_0" in param.name and "conv2d" in param.name):
......@@ -150,6 +135,7 @@ def main(config, device, logger, vdl_writer):
logger.info(f"FLOPs after pruning: {flops}")
# start train
program.train(config, train_dataloader, valid_dataloader, device, model,
loss_class, optimizer, lr_scheduler, post_process_class,
eval_class, pre_best_model_dict, logger, vdl_writer)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册