diff --git a/configs/det/det_mv3_db.yml b/configs/det/det_mv3_db.yml index fc0c007da2057c48dbd81bbdfdbfa9397789d2c1..a997aa38fcd46ab58f531e15b62927bcfd6c1992 100644 --- a/configs/det/det_mv3_db.yml +++ b/configs/det/det_mv3_db.yml @@ -10,8 +10,8 @@ Global: # if pretrained_model is saved in static mode, load_static_weights must set to True load_static_weights: True cal_metric_during_train: False - pretrained_model: /home/zhoujun20/pretrain_models/MobileNetV3_large_x0_5_pretrained - checkpoints: #./output/det_db_0.001_DiceLoss_256_pp_config_2.0b_4gpu/best_accuracy + pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained + checkpoints: save_inference_dir: use_visualdl: True infer_img: doc/imgs_en/img_10.jpg @@ -22,9 +22,7 @@ Optimizer: beta1: 0.9 beta2: 0.999 learning_rate: -# name: Cosine lr: 0.001 -# warmup_epoch: 0 regularizer: name: 'L2' factor: 0 @@ -98,7 +96,7 @@ TRAIN: order: 'hwc' - ToCHWImage: - keepKeys: - keep_keys: ['image','threshold_map','threshold_mask','shrink_map','shrink_mask'] # dataloader将按照此顺序返回list + keep_keys: ['image','threshold_map','threshold_mask','shrink_map','shrink_mask'] # dataloader will return list in this order loader: shuffle: True drop_last: False diff --git a/configs/det/det_r50_vd_db.yml b/configs/det/det_r50_vd_db.yml index 57940926280d67356924f78514102982124b8564..a07273b4ae294164c0c5d8166ec602beade55259 100644 --- a/configs/det/det_r50_vd_db.yml +++ b/configs/det/det_r50_vd_db.yml @@ -3,15 +3,15 @@ Global: epoch_num: 1200 log_smooth_window: 20 print_batch_step: 2 - save_model_dir: ./output/20201015_r50/ + save_model_dir: ./output/det_r50_vd/ save_epoch_step: 1200 # evaluation is run every 5000 iterations after the 4000th iteration eval_batch_step: 8 # if pretrained_model is saved in static mode, load_static_weights must set to True load_static_weights: True cal_metric_during_train: False - pretrained_model: /home/zhoujun20/pretrain_models/ResNet50_vd_ssld_pretrained/ - checkpoints: #./output/det_db_0.001_DiceLoss_256_pp_config_2.0b_4gpu/best_accuracy + pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained/ + checkpoints: save_inference_dir: use_visualdl: True infer_img: doc/imgs_en/img_10.jpg @@ -22,9 +22,7 @@ Optimizer: beta1: 0.9 beta2: 0.999 learning_rate: -# name: Cosine lr: 0.001 -# warmup_epoch: 0 regularizer: name: 'L2' factor: 0 @@ -65,9 +63,9 @@ Metric: TRAIN: dataset: name: SimpleDataSet - data_dir: /home/zhoujun20/detection/ + data_dir: ./detection/ file_list: - - /home/zhoujun20/detection/train_icdar2015_label.txt # dataset1 + - ./detection/train_icdar2015_label.txt # dataset1 ratio_list: [1.0] transforms: - DecodeImage: # load image @@ -97,7 +95,7 @@ TRAIN: order: 'hwc' - ToCHWImage: - keepKeys: - keep_keys: ['image','threshold_map','threshold_mask','shrink_map','shrink_mask'] # dataloader将按照此顺序返回list + keep_keys: ['image','threshold_map','threshold_mask','shrink_map','shrink_mask'] # dataloader will return list in this order loader: shuffle: True drop_last: False @@ -107,9 +105,9 @@ TRAIN: EVAL: dataset: name: SimpleDataSet - data_dir: /home/zhoujun20/detection/ + data_dir: ./detection/ file_list: - - /home/zhoujun20/detection/test_icdar2015_label.txt + - ./detection/test_icdar2015_label.txt transforms: - DecodeImage: # load image img_mode: BGR diff --git a/configs/rec/rec_mv3_none_bilstm_ctc.yml b/configs/rec/rec_mv3_none_bilstm_ctc.yml index 7119e0e2bd5ee190025d30f37dcfbd25661b6b6c..1be7512c9d793b38b7d5c23ab4e55972e793c28b 100644 --- a/configs/rec/rec_mv3_none_bilstm_ctc.yml +++ b/configs/rec/rec_mv3_none_bilstm_ctc.yml @@ -3,7 +3,7 @@ Global: epoch_num: 500 log_smooth_window: 20 print_batch_step: 10 - save_model_dir: ./output/rec/test/ + save_model_dir: ./output/rec/mv3_none_bilstm_ctc/ save_epoch_step: 500 # evaluation is run every 5000 iterations after the 4000th iteration eval_batch_step: 127 @@ -11,7 +11,7 @@ Global: load_static_weights: True cal_metric_during_train: True pretrained_model: - checkpoints: #output/rec/rec_crnn/best_accuracy + checkpoints: save_inference_dir: use_visualdl: False infer_img: doc/imgs_words/ch/word_1.jpg @@ -29,9 +29,7 @@ Optimizer: beta1: 0.9 beta2: 0.999 learning_rate: - name: Cosine lr: 0.001 - warmup_epoch: 4 regularizer: name: 'L2' factor: 0.00001 @@ -66,9 +64,9 @@ Metric: TRAIN: dataset: name: SimpleDataSet - data_dir: /home/zhoujun20/rec + data_dir: ./rec file_list: - - /home/zhoujun20/rec/real_data.txt # dataset1 + - ./rec/train.txt # dataset1 ratio_list: [ 0.4,0.6 ] transforms: - DecodeImage: # load image @@ -79,7 +77,7 @@ TRAIN: - RecResizeImg: image_shape: [ 3,32,320 ] - keepKeys: - keep_keys: [ 'image','label','length' ] # dataloader将按照此顺序返回list + keep_keys: [ 'image','label','length' ] # dataloader will return list in this order loader: batch_size: 256 shuffle: True @@ -89,9 +87,9 @@ TRAIN: EVAL: dataset: name: SimpleDataSet - data_dir: /home/zhoujun20/rec + data_dir: ./rec file_list: - - /home/zhoujun20/rec/label_val_all.txt + - ./rec/val.txt transforms: - DecodeImage: # load image img_mode: BGR @@ -100,7 +98,7 @@ EVAL: - RecResizeImg: image_shape: [ 3,32,320 ] - keepKeys: - keep_keys: [ 'image','label','length' ] # dataloader将按照此顺序返回list + keep_keys: [ 'image','label','length' ] # dataloader will return list in this order loader: shuffle: False drop_last: False diff --git a/configs/rec/rec_mv3_none_bilstm_ctc_lmdb.yml b/configs/rec/rec_mv3_none_bilstm_ctc_lmdb.yml index 1887680ff50ef91d1b7be2e7e1940642dfa46c85..f917b0d8caa71bfa9bcef8d0e37df6f8ca163bf8 100644 --- a/configs/rec/rec_mv3_none_bilstm_ctc_lmdb.yml +++ b/configs/rec/rec_mv3_none_bilstm_ctc_lmdb.yml @@ -3,7 +3,7 @@ Global: epoch_num: 500 log_smooth_window: 20 print_batch_step: 1 - save_model_dir: ./output/rec/test/ + save_model_dir: ./output/rec/mv3_none_bilstm_ctc/ save_epoch_step: 500 # evaluation is run every 5000 iterations after the 4000th iteration eval_batch_step: 1016 @@ -11,13 +11,13 @@ Global: load_static_weights: True cal_metric_during_train: True pretrained_model: - checkpoints: #output/rec/rec_crnn/best_accuracy + checkpoints: save_inference_dir: use_visualdl: True infer_img: doc/imgs_words/ch/word_1.jpg # for data or label process max_text_length: 80 - character_dict_path: /home/zhoujun20/rec/lmdb/dict.txt + character_dict_path: ppocr/utils/ppocr_keys_v1.txt character_type: 'ch' use_space_char: True infer_mode: False @@ -29,9 +29,7 @@ Optimizer: beta1: 0.9 beta2: 0.999 learning_rate: - name: Cosine lr: 0.0005 - warmup_epoch: 1 regularizer: name: 'L2' factor: 0.00001 @@ -67,7 +65,7 @@ TRAIN: dataset: name: LMDBDateSet file_list: - - /home/zhoujun20/rec/lmdb/train # dataset1 + - ./rec/lmdb/train # dataset1 ratio_list: [ 0.4,0.6 ] transforms: - DecodeImage: # load image @@ -78,7 +76,7 @@ TRAIN: - RecResizeImg: image_shape: [ 3,32,320 ] - keepKeys: - keep_keys: [ 'image','label','length' ] # dataloader将按照此顺序返回list + keep_keys: [ 'image','label','length' ] # dataloader will return list in this order loader: batch_size: 256 shuffle: True @@ -89,7 +87,7 @@ EVAL: dataset: name: LMDBDateSet file_list: - - /home/zhoujun20/rec/lmdb/val + - ./rec/lmdb/val transforms: - DecodeImage: # load image img_mode: BGR @@ -98,7 +96,7 @@ EVAL: - RecResizeImg: image_shape: [ 3,32,320 ] - keepKeys: - keep_keys: [ 'image','label','length' ] # dataloader将按照此顺序返回list + keep_keys: [ 'image','label','length' ] # dataloader will return list in this order loader: shuffle: False drop_last: False diff --git a/configs/rec/rec_mv3_none_none_ctc_lmdb.yml b/configs/rec/rec_mv3_none_none_ctc_lmdb.yml index 413e1c3c315ae90fbc3f096dfbb7f8bb3a8f39e9..19997fd59973ec429da24085047b9578269aa91a 100644 --- a/configs/rec/rec_mv3_none_none_ctc_lmdb.yml +++ b/configs/rec/rec_mv3_none_none_ctc_lmdb.yml @@ -1,25 +1,25 @@ Global: use_gpu: false - epoch_num: 500 + epoch_num: 72 log_smooth_window: 20 - print_batch_step: 1 - save_model_dir: ./output/rec/test/ + print_batch_step: 10 + save_model_dir: ./output/rec/mv3_none_none_ctc/ save_epoch_step: 500 # evaluation is run every 5000 iterations after the 4000th iteration - eval_batch_step: 1016 + eval_batch_step: 2000 # if pretrained_model is saved in static mode, load_static_weights must set to True load_static_weights: True cal_metric_during_train: True pretrained_model: - checkpoints: #output/rec/rec_crnn/best_accuracy + checkpoints: save_inference_dir: use_visualdl: True infer_img: doc/imgs_words/ch/word_1.jpg # for data or label process - max_text_length: 80 - character_dict_path: /home/zhoujun20/rec/lmdb/dict.txt + max_text_length: 25 + character_dict_path: character_type: 'en' - use_space_char: True + use_space_char: False infer_mode: False use_tps: False @@ -29,9 +29,7 @@ Optimizer: beta1: 0.9 beta2: 0.999 learning_rate: - name: Cosine lr: 0.0005 - warmup_epoch: 1 regularizer: name: 'L2' factor: 0.00001 @@ -43,7 +41,7 @@ Architecture: Backbone: name: MobileNetV3 scale: 0.5 - model_name: small + model_name: large small_stride: [ 1, 2, 2, 2 ] Neck: name: SequenceEncoder @@ -66,7 +64,7 @@ TRAIN: dataset: name: LMDBDateSet file_list: - - /Users/zhoujun20/Downloads/evaluation_new # dataset1 + - ./rec/train # dataset1 ratio_list: [ 0.4,0.6 ] transforms: - DecodeImage: # load image @@ -75,9 +73,9 @@ TRAIN: - CTCLabelEncode: # Class handling label - RecAug: - RecResizeImg: - image_shape: [ 3,32,320 ] + image_shape: [ 3,32,100 ] - keepKeys: - keep_keys: [ 'image','label','length' ] # dataloader将按照此顺序返回list + keep_keys: [ 'image','label','length' ] # dataloader will return list in this order loader: batch_size: 256 shuffle: True @@ -88,16 +86,16 @@ EVAL: dataset: name: LMDBDateSet file_list: - - /home/zhoujun20/rec/lmdb/val + - ./rec/val/ transforms: - DecodeImage: # load image img_mode: BGR channel_first: False - CTCLabelEncode: # Class handling label - RecResizeImg: - image_shape: [ 3,32,320 ] + image_shape: [ 3,32,100 ] - keepKeys: - keep_keys: [ 'image','label','length' ] # dataloader将按照此顺序返回list + keep_keys: [ 'image','label','length' ] # dataloader will return list in this order loader: shuffle: False drop_last: False diff --git a/configs/rec/rec_r34_vd_none_bilstm_ctc.yml b/configs/rec/rec_r34_vd_none_bilstm_ctc.yml index e87115dfc947ede56fb1feaf24824c1535dcac76..36e3d1c81cb5e5ad744576dc6d454e8f31d965dc 100644 --- a/configs/rec/rec_r34_vd_none_bilstm_ctc.yml +++ b/configs/rec/rec_r34_vd_none_bilstm_ctc.yml @@ -3,7 +3,7 @@ Global: epoch_num: 500 log_smooth_window: 20 print_batch_step: 10 - save_model_dir: ./output/rec/test/ + save_model_dir: ./output/rec/res34_none_bilstm_ctc/ save_epoch_step: 500 # evaluation is run every 5000 iterations after the 4000th iteration eval_batch_step: 127 @@ -11,7 +11,7 @@ Global: load_static_weights: True cal_metric_during_train: True pretrained_model: - checkpoints: #output/rec/rec_crnn/best_accuracy + checkpoints: save_inference_dir: use_visualdl: False infer_img: doc/imgs_words/ch/word_1.jpg @@ -29,9 +29,7 @@ Optimizer: beta1: 0.9 beta2: 0.999 learning_rate: - name: Cosine lr: 0.001 - warmup_epoch: 4 regularizer: name: 'L2' factor: 0.00001 @@ -64,9 +62,9 @@ Metric: TRAIN: dataset: name: SimpleDataSet - data_dir: /home/zhoujun20/rec + data_dir: ./rec file_list: - - /home/zhoujun20/rec/real_data.txt # dataset1 + - ./rec/train.txt # dataset1 ratio_list: [ 0.4,0.6 ] transforms: - DecodeImage: # load image @@ -77,7 +75,7 @@ TRAIN: - RecResizeImg: image_shape: [ 3,32,320 ] - keepKeys: - keep_keys: [ 'image','label','length' ] # dataloader将按照此顺序返回list + keep_keys: [ 'image','label','length' ] # dataloader will return list in this order loader: batch_size: 256 shuffle: True @@ -87,9 +85,9 @@ TRAIN: EVAL: dataset: name: SimpleDataSet - data_dir: /home/zhoujun20/rec + data_dir: ./rec file_list: - - /home/zhoujun20/rec/label_val_all.txt + - ./rec/val.txt transforms: - DecodeImage: # load image img_mode: BGR @@ -98,7 +96,7 @@ EVAL: - RecResizeImg: image_shape: [ 3,32,320 ] - keepKeys: - keep_keys: [ 'image','label','length' ] # dataloader将按照此顺序返回list + keep_keys: [ 'image','label','length' ] # dataloader will return list in this order loader: shuffle: False drop_last: False diff --git a/configs/rec/rec_r34_vd_none_none_ctc.yml b/configs/rec/rec_r34_vd_none_none_ctc.yml new file mode 100644 index 0000000000000000000000000000000000000000..641e855b431e459536453275759c6a5f064c15fb --- /dev/null +++ b/configs/rec/rec_r34_vd_none_none_ctc.yml @@ -0,0 +1,103 @@ +Global: + use_gpu: false + epoch_num: 500 + log_smooth_window: 20 + print_batch_step: 10 + save_model_dir: ./output/rec/res34_none_none_ctc/ + save_epoch_step: 500 + # evaluation is run every 5000 iterations after the 4000th iteration + eval_batch_step: 127 + # if pretrained_model is saved in static mode, load_static_weights must set to True + load_static_weights: True + cal_metric_during_train: True + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_words/ch/word_1.jpg + # for data or label process + max_text_length: 80 + character_dict_path: ppocr/utils/ppocr_keys_v1.txt + character_type: 'ch' + use_space_char: False + infer_mode: False + use_tps: False + + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + learning_rate: + lr: 0.001 + regularizer: + name: 'L2' + factor: 0.00001 + +Architecture: + type: rec + algorithm: CRNN + Transform: + Backbone: + name: ResNet + layers: 34 + Neck: + name: SequenceEncoder + encoder_type: reshape + Head: + name: CTC + fc_decay: 0.00001 + +Loss: + name: CTCLoss + +PostProcess: + name: CTCLabelDecode + +Metric: + name: RecMetric + main_indicator: acc + +TRAIN: + dataset: + name: SimpleDataSet + data_dir: ./rec + file_list: + - ./rec/train.txt # dataset1 + ratio_list: [ 0.4,0.6 ] + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - CTCLabelEncode: # Class handling label + - RecAug: + - RecResizeImg: + image_shape: [ 3,32,320 ] + - keepKeys: + keep_keys: [ 'image','label','length' ] # dataloader will return list in this order + loader: + batch_size: 256 + shuffle: True + drop_last: True + num_workers: 8 + +EVAL: + dataset: + name: SimpleDataSet + data_dir: ./rec + file_list: + - ./rec/val.txt + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - CTCLabelEncode: # Class handling label + - RecResizeImg: + image_shape: [ 3,32,320 ] + - keepKeys: + keep_keys: [ 'image','label','length' ] # dataloader will return list in this order + loader: + shuffle: False + drop_last: False + batch_size: 256 + num_workers: 8 diff --git a/ppocr/modeling/architectures/model.py b/ppocr/modeling/architectures/model.py index 5723beb7d05718fe40810f8f345b11265eff805a..222b08d6c84ac4edb0cfd822c2c88ce9218dcc10 100644 --- a/ppocr/modeling/architectures/model.py +++ b/ppocr/modeling/architectures/model.py @@ -21,7 +21,6 @@ __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) sys.path.append('/home/zhoujun20/PaddleOCR') -import paddle from paddle import nn from ppocr.modeling.transform import build_transform from ppocr.modeling.backbones import build_backbone @@ -72,12 +71,10 @@ class Model(nn.Layer): config['Neck']['in_channels'] = in_channels self.neck = build_neck(config['Neck']) in_channels = self.neck.out_channels - - # # build head, head is need for del, rec and cls + # # build head, head is need for det, rec and cls config["Head"]['in_channels'] = in_channels self.head = build_head(config["Head"]) - # @paddle.jit.to_static def forward(self, x): if self.use_transform: x = self.transform(x) @@ -85,41 +82,4 @@ class Model(nn.Layer): if self.use_neck: x = self.neck(x) x = self.head(x) - return x - - -def check_static(): - import numpy as np - from ppocr.utils.save_load import load_dygraph_pretrain - from ppocr.utils.logging import get_logger - from tools import program - - config = program.load_config('configs/rec/rec_r34_vd_none_bilstm_ctc.yml') - - logger = get_logger() - np.random.seed(0) - data = np.random.rand(1, 3, 32, 320).astype(np.float32) - paddle.disable_static() - - config['Architecture']['in_channels'] = 3 - config['Architecture']["Head"]['out_channels'] = 6624 - model = Model(config['Architecture']) - model.eval() - load_dygraph_pretrain( - model, - logger, - '/Users/zhoujun20/Desktop/code/PaddleOCR/cnn_ctc/cnn_ctc', - load_static_weights=True) - x = paddle.to_tensor(data) - y = model(x) - for y1 in y: - print(y1.shape) - - static_out = np.load( - '/Users/zhoujun20/Desktop/code/PaddleOCR/output/conv.npy') - diff = y.numpy() - static_out - print(y.shape, static_out.shape, diff.mean()) - - -if __name__ == '__main__': - check_static() + return x \ No newline at end of file diff --git a/ppocr/modeling/heads/rec_ctc_head.py b/ppocr/modeling/heads/rec_ctc_head.py index 8c7b904fed741d947e580611de3e9d8cb2f312f4..e96b96ad8168bb2d045f99458398ea75807a8b50 100755 --- a/ppocr/modeling/heads/rec_ctc_head.py +++ b/ppocr/modeling/heads/rec_ctc_head.py @@ -20,6 +20,7 @@ import math import paddle from paddle import ParamAttr, nn +from paddle.nn import functional as F def get_para_bias_attr(l2_decay, k, name): @@ -48,4 +49,6 @@ class CTC(nn.Layer): def forward(self, x, labels=None): predicts = self.fc(x) + if not self.training: + predicts = F.softmax(predicts, axis=2) return predicts diff --git a/tools/infer/utility.py b/tools/infer/utility.py index dab06349a7d740b317c1d6257c839766bba073bb..cfbec95605e261994555392bedb9f3abd26c31fe 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -14,11 +14,14 @@ import argparse import os +import sys import cv2 import numpy as np import json from PIL import Image, ImageDraw, ImageFont import math +from paddle.fluid.core import AnalysisConfig +from paddle.fluid.core import create_paddle_predictor def parse_args(): @@ -71,6 +74,59 @@ def parse_args(): return parser.parse_args() +def create_predictor(args, mode, logger): + if mode == "det": + model_dir = args.det_model_dir + elif mode == 'cls': + model_dir = args.cls_model_dir + else: + model_dir = args.rec_model_dir + + if model_dir is None: + logger.info("not find {} model file path {}".format(mode, model_dir)) + sys.exit(0) + model_file_path = model_dir + "/model" + params_file_path = model_dir + "/params" + if not os.path.exists(model_file_path): + logger.info("not find model file path {}".format(model_file_path)) + sys.exit(0) + if not os.path.exists(params_file_path): + logger.info("not find params file path {}".format(params_file_path)) + sys.exit(0) + + config = AnalysisConfig(model_file_path, params_file_path) + + if args.use_gpu: + config.enable_use_gpu(args.gpu_mem, 0) + else: + config.disable_gpu() + config.set_cpu_math_library_num_threads(6) + if args.enable_mkldnn: + # cache 10 different shapes for mkldnn to avoid memory leak + config.set_mkldnn_cache_capacity(10) + config.enable_mkldnn() + + # config.enable_memory_optim() + config.disable_glog_info() + + if args.use_zero_copy_run: + config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass") + config.switch_use_feed_fetch_ops(False) + else: + config.switch_use_feed_fetch_ops(True) + + predictor = create_paddle_predictor(config) + input_names = predictor.get_input_names() + for name in input_names: + input_tensor = predictor.get_input_tensor(name) + output_names = predictor.get_output_names() + output_tensors = [] + for output_name in output_names: + output_tensor = predictor.get_output_tensor(output_name) + output_tensors.append(output_tensor) + return predictor, input_tensor, output_tensors + + def draw_text_det_res(dt_boxes, img_path): src_im = cv2.imread(img_path) for box in dt_boxes: