Commit 20f001d3 authored by 文幕地方

add rec fpgm

Parent: 7c15d1b7
@@ -52,12 +52,17 @@ def main(config, device, logger, vdl_writer):
         config['Architecture']["Head"]['out_channels'] = char_num
     model = build_model(config['Architecture'])
-    flops = paddle.flops(model, [1, 3, 640, 640])
-    logger.info(f"FLOPs before pruning: {flops}")
+    if config['Architecture']['model_type'] == 'det':
+        input_shape = [1, 3, 640, 640]
+    elif config['Architecture']['model_type'] == 'rec':
+        input_shape = [1, 3, 32, 320]
+    flops = paddle.flops(model, input_shape)
+    logger.info("FLOPs before pruning: {}".format(flops))
     from paddleslim.dygraph import FPGMFilterPruner
     model.train()
-    pruner = FPGMFilterPruner(model, [1, 3, 640, 640])
+    pruner = FPGMFilterPruner(model, input_shape)
     # build metric
     eval_class = build_metric(config['Metric'])
@@ -65,8 +70,13 @@ def main(config, device, logger, vdl_writer):
     def eval_fn():
         metric = program.eval(model, valid_dataloader, post_process_class,
                               eval_class)
-        logger.info(f"metric['hmean']: {metric['hmean']}")
-        return metric['hmean']
+        if config['Architecture']['model_type'] == 'det':
+            main_indicator = 'hmean'
+        else:
+            main_indicator = 'acc'
+        logger.info("metric[{}]: {}".format(main_indicator, metric[
+            main_indicator]))
+        return metric[main_indicator]
     params_sensitive = pruner.sensitive(
         eval_func=eval_fn,
@@ -81,18 +91,22 @@ def main(config, device, logger, vdl_writer):
     # calculate pruned params's ratio
     params_sensitive = pruner._get_ratios_by_loss(params_sensitive, loss=0.02)
     for key in params_sensitive.keys():
-        logger.info(f"{key}, {params_sensitive[key]}")
+        logger.info("{}, {}".format(key, params_sensitive[key]))
     plan = pruner.prune_vars(params_sensitive, [0])
-    flops = paddle.flops(model, [1, 3, 640, 640])
-    logger.info(f"FLOPs after pruning: {flops}")
+    flops = paddle.flops(model, input_shape)
+    logger.info("FLOPs after pruning: {}".format(flops))
     # load pretrain model
     load_model(config, model)
     metric = program.eval(model, valid_dataloader, post_process_class,
                           eval_class)
-    logger.info(f"metric['hmean']: {metric['hmean']}")
+    if config['Architecture']['model_type'] == 'det':
+        main_indicator = 'hmean'
+    else:
+        main_indicator = 'acc'
+    logger.info("metric[{}]: {}".format(main_indicator, metric[main_indicator]))
     # start export model
     from paddle.jit import to_static
......
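For context, this first hunk (presumably deploy/slim/prune/export_prune_model.py, given the export step above and the fpgm_export entry further down) makes the profiling shape depend on the model type before handing it to paddle.flops and the FPGM pruner. Below is a minimal standalone sketch of the same pattern, using a toy stand-in for the model that build_model would return:

import paddle
import paddle.nn as nn
from paddleslim.dygraph import FPGMFilterPruner

def get_input_shape(model_type):
    # Same shapes as the diff: det models are profiled at 640x640,
    # rec models at 32x320.
    if model_type == 'det':
        return [1, 3, 640, 640]
    elif model_type == 'rec':
        return [1, 3, 32, 320]
    raise ValueError('unsupported model_type: {}'.format(model_type))

# Toy stand-in for the model returned by build_model(config['Architecture']).
toy_model = nn.Sequential(
    nn.Conv2D(3, 16, 3, padding=1), nn.ReLU(),
    nn.Conv2D(16, 32, 3, padding=1), nn.ReLU())

input_shape = get_input_shape('rec')  # [1, 3, 32, 320]
print('FLOPs before pruning:', paddle.flops(toy_model, input_shape))

toy_model.train()
pruner = FPGMFilterPruner(toy_model, input_shape)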
@@ -73,13 +73,18 @@ def main(config, device, logger, vdl_writer):
         char_num = len(getattr(post_process_class, 'character'))
         config['Architecture']["Head"]['out_channels'] = char_num
     model = build_model(config['Architecture'])
+    if config['Architecture']['model_type'] == 'det':
+        input_shape = [1, 3, 640, 640]
+    elif config['Architecture']['model_type'] == 'rec':
+        input_shape = [1, 3, 32, 320]
+    flops = paddle.flops(model, input_shape)
-    flops = paddle.flops(model, [1, 3, 640, 640])
     logger.info("FLOPs before pruning: {}".format(flops))
     from paddleslim.dygraph import FPGMFilterPruner
     model.train()
-    pruner = FPGMFilterPruner(model, [1, 3, 640, 640])
+    pruner = FPGMFilterPruner(model, input_shape)
     # build loss
     loss_class = build_loss(config['Loss'])
@@ -107,8 +112,14 @@ def main(config, device, logger, vdl_writer):
     def eval_fn():
         metric = program.eval(model, valid_dataloader, post_process_class,
                               eval_class, False)
-        logger.info("metric['hmean']: {}".format(metric['hmean']))
-        return metric['hmean']
+        if config['Architecture']['model_type'] == 'det':
+            main_indicator = 'hmean'
+        else:
+            main_indicator = 'acc'
+        logger.info("metric[{}]: {}".format(main_indicator, metric[
+            main_indicator]))
+        return metric[main_indicator]
     run_sensitive_analysis = False
     """
@@ -149,7 +160,7 @@ def main(config, device, logger, vdl_writer):
     plan = pruner.prune_vars(params_sensitive, [0])
-    flops = paddle.flops(model, [1, 3, 640, 640])
+    flops = paddle.flops(model, input_shape)
     logger.info("FLOPs after pruning: {}".format(flops))
     # start train
......
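The second hunk (presumably deploy/slim/prune/sensitivity_anal.py, the script the fpgm_train entry below invokes) applies the same input_shape logic and then runs PaddleSlim's sensitivity analysis. A hedged continuation of the sketch above, with a dummy eval_fn in place of program.eval; the sen_file argument is an assumption based on PaddleSlim's usual API, not something shown in this diff:

# `pruner`, `toy_model` and `input_shape` come from the previous sketch.
_calls = {'n': 0}

def eval_fn():
    # Stand-in for program.eval(...); returns a slowly decreasing score so the
    # sensitivity curve is not flat. The real code returns metric['hmean'] for
    # det models and metric['acc'] for rec models.
    _calls['n'] += 1
    return max(0.0, 1.0 - 0.01 * _calls['n'])

# Measure how the metric degrades as each conv layer is pruned harder.
params_sensitive = pruner.sensitive(eval_func=eval_fn, sen_file='./sen.pickle')

# Turn the sensitivity table into per-layer ratios that keep the estimated
# metric loss under 2%, then prune filters along axis 0 (output channels).
params_sensitive = pruner._get_ratios_by_loss(params_sensitive, loss=0.02)
plan = pruner.prune_vars(params_sensitive, [0])
print('FLOPs after pruning:', paddle.flops(toy_model, input_shape))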
@@ -26,8 +26,10 @@ class MobileNetV3(nn.Layer):
                  scale=0.5,
                  large_stride=None,
                  small_stride=None,
+                 disable_se=False,
                  **kwargs):
         super(MobileNetV3, self).__init__()
+        self.disable_se = disable_se
         if small_stride is None:
             small_stride = [2, 2, 2, 2]
         if large_stride is None:
@@ -101,6 +103,7 @@ class MobileNetV3(nn.Layer):
         block_list = []
         inplanes = make_divisible(inplanes * scale)
         for (k, exp, c, se, nl, s) in cfg:
+            se = se and not self.disable_se
             block_list.append(
                 ResidualUnit(
                     in_channels=inplanes,
......
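The backbone change adds a disable_se switch so the recognition MobileNetV3 can be built without its squeeze-and-excitation blocks, presumably to keep the pruned channel layout simple. A tiny illustrative sketch of how the flag masks the per-block se entry from the config tuples; the class and config values here are made up for demonstration:

class TinyBackbone:
    # Illustrative only: mirrors how MobileNetV3 applies disable_se.
    def __init__(self, disable_se=False):
        self.disable_se = disable_se
        # (kernel, expansion, out_channels, use_se, nonlinearity, stride)
        self.cfg = [
            (3, 16, 16, True, 'relu', 2),
            (3, 72, 24, False, 'relu', 2),
        ]

    def block_se_flags(self):
        flags = []
        for (k, exp, c, se, nl, s) in self.cfg:
            # Same guard as the diff: a block only keeps SE when the
            # backbone-level switch allows it.
            se = se and not self.disable_se
            flags.append(se)
        return flags

print(TinyBackbone(disable_se=False).block_se_flags())  # [True, False]
print(TinyBackbone(disable_se=True).block_se_flags())   # [False, False]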
Global:
  use_gpu: true
  epoch_num: 500
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/rec_chinese_lite_v2.0
  save_epoch_step: 3
  # evaluation is run every 2000 iterations
  eval_batch_step: [0, 2000]
  cal_metric_during_train: True
  pretrained_model:
  checkpoints:
  save_inference_dir:
  use_visualdl: False
  infer_img: doc/imgs_words/ch/word_1.jpg
  # for data and label processing
  character_dict_path: ppocr/utils/ppocr_keys_v1.txt
  max_text_length: 25
  infer_mode: False
  use_space_char: True
  save_res_path: ./output/rec/predicts_chinese_lite_v2.0.txt
Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Cosine
    learning_rate: 0.001
  regularizer:
    name: 'L2'
    factor: 0.00001

Architecture:
  model_type: rec
  algorithm: CRNN
  Transform:
  Backbone:
    name: MobileNetV3
    scale: 0.5
    model_name: small
    small_stride: [1, 2, 2, 2]
    disable_se: True
  Neck:
    name: SequenceEncoder
    encoder_type: rnn
    hidden_size: 48
  Head:
    name: CTCHead
    fc_decay: 0.00001

Loss:
  name: CTCLoss

PostProcess:
  name: CTCLabelDecode

Metric:
  name: RecMetric
  main_indicator: acc

Train:
  dataset:
    name: SimpleDataSet
    data_dir: train_data/ic15_data
    label_file_list: ["train_data/ic15_data/rec_gt_train.txt"]
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - RecAug:
      - CTCLabelEncode: # Class handling label
      - RecResizeImg:
          image_shape: [3, 32, 320]
      - KeepKeys:
          keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
  loader:
    shuffle: True
    batch_size_per_card: 256
    drop_last: True
    num_workers: 8

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: train_data/ic15_data
    label_file_list: ["train_data/ic15_data/rec_gt_test.txt"]
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - CTCLabelEncode: # Class handling label
      - RecResizeImg:
          image_shape: [3, 32, 320]
      - KeepKeys:
          keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 256
    num_workers: 8
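The new config ties the pieces together: model_type: rec selects the [1, 3, 32, 320] profiling shape in the scripts above, which matches the RecResizeImg image_shape of [3, 32, 320], and Backbone.disable_se: True exercises the new backbone flag. A small hedged check of that correspondence; the path is the one the fpgm_train entry below points at, assumed to be this file:

import yaml  # PyYAML, assumed to be installed

cfg_path = ('test_tipc/configs/ch_ppocr_mobile_v2.0_rec_FPGM/'
            'rec_chinese_lite_train_v2.0.yml')
with open(cfg_path) as f:
    cfg = yaml.safe_load(f)

model_type = cfg['Architecture']['model_type']  # 'rec'
input_shape = [1, 3, 640, 640] if model_type == 'det' else [1, 3, 32, 320]

# The eval pipeline resizes images to the same C x H x W the pruner profiles with.
rec_resize = next(t['RecResizeImg']
                  for t in cfg['Eval']['dataset']['transforms']
                  if 'RecResizeImg' in t)
assert input_shape[1:] == rec_resize['image_shape']  # [3, 32, 320]
assert cfg['Architecture']['Backbone']['disable_se'] is True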
===========================train_params===========================
model_name:ch_ppocr_mobile_v2.0_rec_FPGM
python:python3.7
gpu_list:0
Global.use_gpu:True|True
Global.auto_cast:null
Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=300
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=128|whole_train_whole_infer=128
Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./train_data/ic15_data/test/word_1.png
null:null
##
trainer:fpgm_train
norm_train:null
pact_train:null
fpgm_train:deploy/slim/prune/sensitivity_anal.py -c test_tipc/configs/ch_ppocr_mobile_v2.0_rec_FPGM/rec_chinese_lite_train_v2.0.yml -o Global.pretrained_model=./pretrain_models/ch_ppocr_mobile_v2.0_rec_train/best_accuracy
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:null
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.pretrained_model:
norm_export:null
quant_export:null
fpgm_export:deploy/slim/prune/export_prune_model.py -c test_tipc/configs/ch_ppocr_mobile_v2.0_rec_FPGM/rec_chinese_lite_train_v2.0.yml -o
distill_export:null
export1:null
export2:null
inference_dir:null
train_model:null
infer_export:null
infer_quant:False
inference:tools/infer/predict_rec.py
--use_gpu:True|False
--enable_mkldnn:True|False
--cpu_threads:1|6
--rec_batch_num:1
--use_tensorrt:False|True
--precision:fp32|int8
--rec_model_dir:
--image_dir:./inference/rec_inference
null:null
--benchmark:True
null:null
\ No newline at end of file
@@ -61,6 +61,10 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
         wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar --no-check-certificate
         cd ./inference/ && tar xf det_r50_vd_db_v2.0_train.tar && cd ../
     fi
+    if [ ${model_name} == "ch_ppocr_mobile_v2.0_rec_FPGM" ]; then
+        wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_train.tar --no-check-certificate
+        cd ./pretrain_models/ && tar xf ch_ppocr_mobile_v2.0_rec_train.tar && cd ../
+    fi
 elif [ ${MODE} = "whole_train_whole_infer" ];then
     wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams --no-check-certificate
......