提交 9cef2409 编写于 作者: 风为何不回来's avatar 风为何不回来

add and update sr model

上级 0cdfc525
...@@ -80,5 +80,5 @@ Eval: ...@@ -80,5 +80,5 @@ Eval:
shuffle: False shuffle: False
drop_last: False drop_last: False
batch_size_per_card: 16 batch_size_per_card: 16
num_workers: 0 num_workers: 4
...@@ -25,6 +25,8 @@ from .det_east_loss import EASTLoss ...@@ -25,6 +25,8 @@ from .det_east_loss import EASTLoss
from .det_sast_loss import SASTLoss from .det_sast_loss import SASTLoss
from .det_pse_loss import PSELoss from .det_pse_loss import PSELoss
from .det_fce_loss import FCELoss from .det_fce_loss import FCELoss
from .det_ct_loss import CTLoss
from .det_drrg_loss import DRRGLoss
# rec loss # rec loss
from .rec_ctc_loss import CTCLoss from .rec_ctc_loss import CTCLoss
...@@ -37,6 +39,7 @@ from .rec_pren_loss import PRENLoss ...@@ -37,6 +39,7 @@ from .rec_pren_loss import PRENLoss
from .rec_multi_loss import MultiLoss from .rec_multi_loss import MultiLoss
from .rec_vl_loss import VLLoss from .rec_vl_loss import VLLoss
from .rec_spin_att_loss import SPINAttentionLoss from .rec_spin_att_loss import SPINAttentionLoss
from .rec_rfl_loss import RFLLoss
# cls loss # cls loss
from .cls_loss import ClsLoss from .cls_loss import ClsLoss
...@@ -69,7 +72,7 @@ def build_loss(config): ...@@ -69,7 +72,7 @@ def build_loss(config):
'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss', 'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss', 'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss', 'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss',
'SLALoss', 'TelescopeLoss' 'SLALoss', 'CTLoss', 'RFLLoss', 'DRRGLoss', 'TelescopeLoss'
] ]
config = copy.deepcopy(config) config = copy.deepcopy(config)
module_name = config.pop('name') module_name = config.pop('name')
......
Global:
use_gpu: true
epoch_num: 2
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/sr/sr_telescope/
save_epoch_step: 3
# evaluation is run every 2000 iterations
eval_batch_step: [0, 1000]
cal_metric_during_train: False
pretrained_model:
checkpoints:
save_inference_dir: ./output/sr/sr_telescope/infer
use_visualdl: False
infer_img: doc/imgs_words_en/word_52.png
# for data or label process
character_dict_path:
max_text_length: 100
infer_mode: False
use_space_char: False
save_res_path: ./output/sr/predicts_telescope.txt
Optimizer:
name: Adam
beta1: 0.5
beta2: 0.999
clip_norm: 0.25
lr:
learning_rate: 0.0001
Architecture:
model_type: sr
algorithm: Telescope
Transform:
name: TBSRN
STN: True
infer_mode: False
Loss:
name: TelescopeLoss
confuse_dict_path: ./ppocr/utils/dict/confuse.pkl
PostProcess:
name: None
Metric:
name: SRMetric
main_indicator: all
Train:
dataset:
name: LMDBDataSetSR
data_dir: ./train_data/TextZoom/train
transforms:
- SRResize:
imgH: 32
imgW: 128
down_sample_scale: 2
- KeepKeys:
keep_keys: ['img_lr', 'img_hr', 'label'] # dataloader will return list in this order
loader:
shuffle: False
batch_size_per_card: 16
drop_last: True
num_workers: 4
Eval:
dataset:
name: LMDBDataSetSR
data_dir: ./train_data/TextZoom/test
transforms:
- SRResize:
imgH: 32
imgW: 128
down_sample_scale: 2
- KeepKeys:
keep_keys: ['img_lr', 'img_hr', 'label'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size_per_card: 16
num_workers: 4
===========================train_params===========================
model_name:sr_telescope
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:null
Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=16
Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./inference/sr_inference
null:null
##
trainer:norm_train
norm_train:tools/train.py -c test_tipc/configs/sr_telescope/sr_telescope.yml -o
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:tools/eval.py -c test_tipc/configs/sr_telescope/sr_telescope.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
norm_export:tools/export_model.py -c test_tipc/configs/sr_telescope/sr_telescope.yml -o
quant_export:null
fpgm_export:null
distill_export:null
export1:null
export2:null
##
train_model:./inference/sr_telescope_train/best_accuracy
infer_export:tools/export_model.py -c test_tipc/configs/sr_telescope/sr_telescope.yml -o
infer_quant:False
inference:tools/infer/predict_sr.py --sr_image_shape="1,32,128" --rec_algorithm="Telescope" --min_subgraph_size=5
--use_gpu:True|False
--enable_mkldnn:False
--cpu_threads:6
--rec_batch_num:1
--use_tensorrt:False
--precision:fp32
--rec_model_dir:
--image_dir:./inference/sr_inference
--save_log_path:./test/output/
--benchmark:True
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[1,32,128]}]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册