Commit fb8e883f authored by LDOUBLEV

refine deploy slim

Parent c64e235a
@@ -24,6 +24,14 @@ sys.path.append(__dir__)
 sys.path.append(os.path.join(__dir__, '..', '..', '..'))
 sys.path.append(os.path.join(__dir__, '..', '..', '..', 'tools'))
+import json
+import cv2
+import paddle
+from paddle import fluid
+import paddleslim as slim
+from copy import deepcopy
+from tools import program
 import paddle
 import paddle.distributed as dist
 from ppocr.data import build_dataloader
@@ -38,14 +46,28 @@ import tools.program as program
 dist.get_world_size()


-def get_pruned_params(parameters):
+def get_pruned_params(parameters, mode="det"):
+    if mode == "det":
+        # hard-coded parameter names from the PP-OCR detection model
+        # that must not be pruned
+        skip_prune_params = [
+            "conv2d_56.w_0", "conv2d_54.w_0", "conv2d_53.w_0",
+            "conv2d_51.w_0", "conv_last_weights", "conv14_linear_weights",
+            "conv13_expand_weights", "conv12_linear_weights",
+            "conv12_expand_weights", "conv7_expand_weights",
+            "conv8_expand_weights", "conv8_linear_weights",
+            "conv5_linear_weights", "conv5_expand_weights",
+            "conv3_linear_weights"
+        ]
+    else:
+        # an empty list (rather than None) keeps the membership test
+        # below valid for non-detection models
+        skip_prune_params = []
     params = []
     for param in parameters:
+        # only plain 4-D conv kernels are candidates; depthwise and
+        # transposed convs and the head convs (conv2d_56/57) are excluded
         if len(
                 param.shape
         ) == 4 and 'depthwise' not in param.name and 'transpose' not in param.name and "conv2d_57" not in param.name and "conv2d_56" not in param.name:
-            params.append(param.name)
+            if param.name not in skip_prune_params:
+                params.append(param.name)
     return params
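For context on the selection rule above: only plain 4-D convolution kernels whose auto-generated names avoid the depthwise/transpose/head patterns and the skip list survive. A minimal sketch of the behavior, assuming a toy paddle.nn model (the layers and expected output are illustrative, not part of the commit; the 'depthwise' filter relies on PP-OCR's backbone naming its depthwise layers explicitly):

```python
import paddle.nn as nn

# Paddle names parameters by creation order ("conv2d_0.w_0",
# "conv2d_transpose_0.w_0", ...), which is why the skip list above can
# pin hard-coded indices such as "conv2d_53.w_0".
toy = nn.Sequential(
    nn.Conv2D(3, 8, 3),           # plain 4-D kernel -> selected
    nn.Conv2DTranspose(8, 4, 2),  # "transpose" in its name -> excluded
    nn.Linear(4, 2),              # 2-D weight -> excluded
)

# mode != "det" leaves the skip list empty
print(get_pruned_params(toy.parameters(), mode="rec"))
# expected (in a fresh process): ['conv2d_0.w_0']
```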
@@ -75,7 +97,7 @@ def main(config, device, logger, vdl_writer):
     model = build_model(config['Architecture'])
     flops = paddle.flops(model, [1, 3, 640, 640])
-    logger.info(f"FLOPs before pruning: {flops}")
+    print(f"FLOPs before pruning: {flops}")
     from paddleslim.dygraph import FPGMFilterPruner
     model.train()
@@ -96,11 +118,6 @@ def main(config, device, logger, vdl_writer):
     # load pretrain model
     pre_best_model_dict = init_model(config, model, logger, optimizer)
-    logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
-                format(len(train_dataloader), len(valid_dataloader)))
-    # build metric
-    eval_class = build_metric(config['Metric'])
     logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
                 format(len(train_dataloader), len(valid_dataloader)))
@@ -110,32 +127,29 @@ def main(config, device, logger, vdl_writer):
         logger.info(f"metric['hmean']: {metric['hmean']}")
         return metric['hmean']

-    params_sensitive = pruner.sensitive(
+    pruner.sensitive(
         eval_func=eval_fn,
         sen_file="./sen.pickle",
         skip_vars=[
            "conv2d_57.w_0", "conv2d_transpose_2.w_0", "conv2d_transpose_3.w_0"
         ])
-    logger.info(
-        "The sensitivity analysis results of model parameters saved in sen.pickle"
-    )
-    # calculate the pruned params' ratios
-    params_sensitive = pruner._get_ratios_by_loss(params_sensitive, loss=0.02)
-    for key in params_sensitive.keys():
-        logger.info(f"{key}, {params_sensitive[key]}")
+    params = get_pruned_params(model.parameters())
+    # set the prune ratio to 0.2 for every prunable parameter
+    ratios = {param: 0.2 for param in params}

-    plan = pruner.prune_vars(params_sensitive, [0])
+    plan = pruner.prune_vars(ratios, [0])
     for param in model.parameters():
         if ("weights" in param.name and "conv" in param.name) or (
                 "w_0" in param.name and "conv2d" in param.name):
-            logger.info(f"{param.name}: {param.shape}")
+            print(f"{param.name}: {param.shape}")
     flops = paddle.flops(model, [1, 3, 640, 640])
-    logger.info(f"FLOPs after pruning: {flops}")
+    print(f"FLOPs after pruning: {flops}")
     # start train
     program.train(config, train_dataloader, valid_dataloader, device, model,
                   loss_class, optimizer, lr_scheduler, post_process_class,
                   eval_class, pre_best_model_dict, logger, vdl_writer)
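Taken together, this hunk replaces the sensitivity-derived ratios (pruner._get_ratios_by_loss over the sen.pickle results) with a flat 20% ratio on every prunable kernel. A condensed sketch of the resulting flow, assuming config, build_model, and get_pruned_params as defined in this script:

```python
import paddle
from paddleslim.dygraph import FPGMFilterPruner

model = build_model(config['Architecture'])         # as in the script
pruner = FPGMFilterPruner(model, [1, 3, 640, 640])  # dummy input shape

# uniform 20% filter pruning along axis 0 (output channels)
ratios = {name: 0.2 for name in get_pruned_params(model.parameters())}
plan = pruner.prune_vars(ratios, [0])

print(paddle.flops(model, [1, 3, 640, 640]))        # FLOPs after pruning
```

The FLOPs counts printed before and after prune_vars are the script's sanity check that the plan actually shrank the network.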
@@ -112,10 +112,6 @@ def main(config, device, logger, vdl_writer):
     config['Architecture']["Head"]['out_channels'] = char_num
     model = build_model(config['Architecture'])
-    # prepare to quant
-    quanter = QAT(config=quant_config, act_preprocess=PACT)
-    quanter.quantize(model)
-
     if config['Global']['distributed']:
         model = paddle.DataParallel(model)
@@ -136,31 +132,15 @@ def main(config, device, logger, vdl_writer):
     logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
                 format(len(train_dataloader), len(valid_dataloader)))

+    quanter = QAT(config=quant_config, act_preprocess=PACT)
+    quanter.quantize(model)
     # start train
     program.train(config, train_dataloader, valid_dataloader, device, model,
                   loss_class, optimizer, lr_scheduler, post_process_class,
                   eval_class, pre_best_model_dict, logger, vdl_writer)


-def test_reader(config, device, logger):
-    loader = build_dataloader(config, 'Train', device, logger)
-    import time
-    starttime = time.time()
-    count = 0
-    try:
-        for data in loader():
-            count += 1
-            if count % 1 == 0:
-                batch_time = time.time() - starttime
-                starttime = time.time()
-                logger.info("reader: {}, {}, {}".format(
-                    count, len(data[0]), batch_time))
-    except Exception as e:
-        logger.info(e)
-    logger.info("finish reader: {}, Success!".format(count))


 if __name__ == '__main__':
     config, device, logger, vdl_writer = program.preprocess(is_train=True)
     main(config, device, logger, vdl_writer)
-    # test_reader(config, device, logger)
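In the second file, the commit moves the QAT wrapping from right after build_model to just before training, once the dataloaders are built. A hedged sketch of the PaddleSlim dygraph QAT pattern this relies on (the quant_config values below are typical PP-OCR settings assumed for illustration, not taken from this diff; model and the PACT activation-preprocess layer come from the script itself):

```python
from paddleslim.dygraph.quant import QAT

quant_config = {
    # assumed, typical settings; the script's own quant_config may differ
    'weight_quantize_type': 'channel_wise_abs_max',
    'activation_quantize_type': 'moving_average_abs_max',
    'weight_bits': 8,
    'activation_bits': 8,
    'quantizable_layer_type': ['Conv2D', 'Linear'],
}

quanter = QAT(config=quant_config, act_preprocess=PACT)  # PACT defined in the script
quanter.quantize(model)  # rewrites layers in place with fake-quant ops

# ... run quantization-aware training via program.train(...), then:
# quanter.save_quantized_model(model, save_path, input_spec=[...])
```

Since quantize() rewrites the model in place, the subsequent program.train call runs quantization-aware training without further changes.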