diff --git a/configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml b/configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml new file mode 100644 index 0000000000000000000000000000000000000000..275c71b97d21e6b168ef6aafae67eb2eb91c6f2b --- /dev/null +++ b/configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml @@ -0,0 +1,134 @@ +Global: + use_gpu: true + epoch_num: 1200 + log_smooth_window: 20 + print_batch_step: 2 + save_model_dir: ./output/ch_db_mv3/ + save_epoch_step: 1200 + # evaluation is run every 5000 iterations after the 4000th iteration + eval_batch_step: [3000, 2000] + # if pretrained_model is saved in static mode, load_static_weights must set to True + load_static_weights: True + cal_metric_during_train: False + pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained + checkpoints: #./output/det_db_0.001_DiceLoss_256_pp_config_2.0b_4gpu/best_accuracy + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_en/img_10.jpg + save_res_path: ./output/det_db/predicts_db.txt + +Architecture: + model_type: det + algorithm: DB + Transform: + Backbone: + name: MobileNetV3 + scale: 0.5 + model_name: large + disable_se: True + Neck: + name: DBFPN + out_channels: 96 + Head: + name: DBHead + k: 50 + +Loss: + name: DBLoss + balance_loss: true + main_loss_type: DiceLoss + alpha: 5 + beta: 10 + ohem_ratio: 3 + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + lr: + name: Cosine + learning_rate: 0.001 + warmup_epoch: 2 + regularizer: + name: 'L2' + factor: 0 + +PostProcess: + name: DBPostProcess + thresh: 0.3 + box_thresh: 0.6 + max_candidates: 1000 + unclip_ratio: 1.5 + +Metric: + name: DetMetric + main_indicator: hmean + +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data/icdar2015/text_localization/ + label_file_list: + - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt + ratio_list: [1.0] + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - DetLabelEncode: # Class handling label + - IaaAugment: + augmenter_args: + - { 'type': Fliplr, 'args': { 'p': 0.5 } } + - { 'type': Affine, 'args': { 'rotate': [-10, 10] } } + - { 'type': Resize, 'args': { 'size': [0.5, 3] } } + - EastRandomCropData: + size: [960, 960] + max_tries: 50 + keep_ratio: true + - MakeBorderMap: + shrink_ratio: 0.4 + thresh_min: 0.3 + thresh_max: 0.7 + - MakeShrinkMap: + shrink_ratio: 0.4 + min_text_size: 8 + - NormalizeImage: + scale: 1./255. + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list + loader: + shuffle: True + drop_last: False + batch_size_per_card: 8 + num_workers: 4 + +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train_data/icdar2015/text_localization/ + label_file_list: + - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - DetLabelEncode: # Class handling label + - DetResizeForTest: +# image_shape: [736, 1280] + - NormalizeImage: + scale: 1./255. 
+ mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: ['image', 'shape', 'polys', 'ignore_tags'] + loader: + shuffle: False + drop_last: False + batch_size_per_card: 1 # must be 1 + num_workers: 2 diff --git a/configs/det/ch_ppocr_v2.0/ch_det_res18_db_v2.0.yml b/configs/det/ch_ppocr_v2.0/ch_det_res18_db_v2.0.yml new file mode 100644 index 0000000000000000000000000000000000000000..e34d94490fbfa81452a7ff507f6568e5ac29e3b7 --- /dev/null +++ b/configs/det/ch_ppocr_v2.0/ch_det_res18_db_v2.0.yml @@ -0,0 +1,133 @@ +Global: + use_gpu: true + epoch_num: 1200 + log_smooth_window: 20 + print_batch_step: 2 + save_model_dir: ./output/ch_db_res18/ + save_epoch_step: 1200 + # evaluation is run every 5000 iterations after the 4000th iteration + eval_batch_step: [3000, 2000] + # if pretrained_model is saved in static mode, load_static_weights must set to True + load_static_weights: True + cal_metric_during_train: False + pretrained_model: ./pretrain_models/ResNet18_vd_pretrained + checkpoints: #./output/det_db_0.001_DiceLoss_256_pp_config_2.0b_4gpu/best_accuracy + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_en/img_10.jpg + save_res_path: ./output/det_db/predicts_db.txt + +Architecture: + model_type: det + algorithm: DB + Transform: + Backbone: + name: ResNet + layers: 18 + disable_se: True + Neck: + name: DBFPN + out_channels: 256 + Head: + name: DBHead + k: 50 + +Loss: + name: DBLoss + balance_loss: true + main_loss_type: DiceLoss + alpha: 5 + beta: 10 + ohem_ratio: 3 + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + lr: + name: Cosine + learning_rate: 0.001 + warmup_epoch: 2 + regularizer: + name: 'L2' + factor: 0 + +PostProcess: + name: DBPostProcess + thresh: 0.3 + box_thresh: 0.6 + max_candidates: 1000 + unclip_ratio: 1.5 + +Metric: + name: DetMetric + main_indicator: hmean + +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data/icdar2015/text_localization/ + label_file_list: + - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt + ratio_list: [1.0] + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - DetLabelEncode: # Class handling label + - IaaAugment: + augmenter_args: + - { 'type': Fliplr, 'args': { 'p': 0.5 } } + - { 'type': Affine, 'args': { 'rotate': [-10, 10] } } + - { 'type': Resize, 'args': { 'size': [0.5, 3] } } + - EastRandomCropData: + size: [960, 960] + max_tries: 50 + keep_ratio: true + - MakeBorderMap: + shrink_ratio: 0.4 + thresh_min: 0.3 + thresh_max: 0.7 + - MakeShrinkMap: + shrink_ratio: 0.4 + min_text_size: 8 + - NormalizeImage: + scale: 1./255. + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list + loader: + shuffle: True + drop_last: False + batch_size_per_card: 8 + num_workers: 4 + +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train_data/icdar2015/text_localization/ + label_file_list: + - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - DetLabelEncode: # Class handling label + - DetResizeForTest: +# image_shape: [736, 1280] + - NormalizeImage: + scale: 1./255. 
+ mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: ['image', 'shape', 'polys', 'ignore_tags'] + loader: + shuffle: False + drop_last: False + batch_size_per_card: 1 # must be 1 + num_workers: 2 diff --git a/ppocr/data/imaug/operators.py b/ppocr/data/imaug/operators.py index c13529276b37b5d8059171aa70e36f0acc99ca04..57cd3b4b172cc52d073999a007fe5a9a343c9186 100644 --- a/ppocr/data/imaug/operators.py +++ b/ppocr/data/imaug/operators.py @@ -42,6 +42,8 @@ class DecodeImage(object): img) > 0, "invalid input 'img' in DecodeImage" img = np.frombuffer(img, dtype='uint8') img = cv2.imdecode(img, 1) + if img is None: + return None if self.img_mode == 'GRAY': img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) elif self.img_mode == 'RGB': @@ -197,7 +199,7 @@ class DetResizeForTest(object): ratio_w = resize_w / float(w) # return img, np.array([h, w]) return img, [ratio_h, ratio_w] - + def resize_image_type2(self, img): h, w, _ = img.shape diff --git a/ppocr/data/simple_dataset.py b/ppocr/data/simple_dataset.py index 097da768aa4da2eb023c4a346fc30f0e704ab953..ab17dd1a3816a86f92707e2812881c892ac59ae6 100644 --- a/ppocr/data/simple_dataset.py +++ b/ppocr/data/simple_dataset.py @@ -27,14 +27,13 @@ class SimpleDataSet(Dataset): global_config = config['Global'] dataset_config = config[mode]['dataset'] loader_config = config[mode]['loader'] - batch_size = loader_config['batch_size_per_card'] self.delimiter = dataset_config.get('delimiter', '\t') label_file_list = dataset_config.pop('label_file_list') data_source_num = len(label_file_list) ratio_list = dataset_config.get("ratio_list", [1.0]) if isinstance(ratio_list, (float, int)): - ratio_list = [float(ratio_list)] * len(data_source_num) + ratio_list = [float(ratio_list)] * int(data_source_num) assert len( ratio_list @@ -76,6 +75,8 @@ class SimpleDataSet(Dataset): label = substr[1] img_path = os.path.join(self.data_dir, file_name) data = {'img_path': img_path, 'label': label} + if not os.path.exists(img_path): + raise Exception("{} does not exist!".format(img_path)) with open(data['img_path'], 'rb') as f: img = f.read() data['image'] = img diff --git a/ppocr/modeling/architectures/base_model.py b/ppocr/modeling/architectures/base_model.py index 0c4fe6503f6e6401581bf92dc9b3421f40722c7c..ab44b53a2bf8214b63954aa867b67a3ed1e05fab 100644 --- a/ppocr/modeling/architectures/base_model.py +++ b/ppocr/modeling/architectures/base_model.py @@ -16,7 +16,7 @@ from __future__ import division from __future__ import print_function from paddle import nn -from ppocr.modeling.transform import build_transform +from ppocr.modeling.transforms import build_transform from ppocr.modeling.backbones import build_backbone from ppocr.modeling.necks import build_neck from ppocr.modeling.heads import build_head diff --git a/ppocr/modeling/backbones/det_mobilenet_v3.py b/ppocr/modeling/backbones/det_mobilenet_v3.py index 017dce2f4ce624c83529492ea9050703814569d4..f97bcfca472310cfd681025595b7bce57d1ccf85 100755 --- a/ppocr/modeling/backbones/det_mobilenet_v3.py +++ b/ppocr/modeling/backbones/det_mobilenet_v3.py @@ -34,13 +34,21 @@ def make_divisible(v, divisor=8, min_value=None): class MobileNetV3(nn.Layer): - def __init__(self, in_channels=3, model_name='large', scale=0.5, **kwargs): + def __init__(self, + in_channels=3, + model_name='large', + scale=0.5, + disable_se=False, + **kwargs): """ the MobilenetV3 backbone network for detection module. 
Args: params(dict): the super parameters for build network """ super(MobileNetV3, self).__init__() + + self.disable_se = disable_se + if model_name == "large": cfg = [ # k, exp, c, se, nl, s, @@ -103,6 +111,7 @@ class MobileNetV3(nn.Layer): i = 0 inplanes = make_divisible(inplanes * scale) for (k, exp, c, se, nl, s) in cfg: + se = se and not self.disable_se if s == 2 and i > 2: self.out_channels.append(inplanes) self.stages.append(nn.Sequential(*block_list)) @@ -273,4 +282,4 @@ class SEModule(nn.Layer): outputs = F.relu(outputs) outputs = self.conv2(outputs) outputs = F.activation.hard_sigmoid(outputs) - return inputs * outputs \ No newline at end of file + return inputs * outputs diff --git a/ppocr/modeling/transform/__init__.py b/ppocr/modeling/transforms/__init__.py similarity index 100% rename from ppocr/modeling/transform/__init__.py rename to ppocr/modeling/transforms/__init__.py diff --git a/ppocr/modeling/transform/tps.py b/ppocr/modeling/transforms/tps.py similarity index 100% rename from ppocr/modeling/transform/tps.py rename to ppocr/modeling/transforms/tps.py diff --git a/ppocr/postprocess/db_postprocess.py b/ppocr/postprocess/db_postprocess.py index 71b8aba7c78198dd018d9bc9d47e477ac0b91cd2..0be2c12ad4bb9e70708c45d3e3f60dd526dc4e83 100644 --- a/ppocr/postprocess/db_postprocess.py +++ b/ppocr/postprocess/db_postprocess.py @@ -33,12 +33,14 @@ class DBPostProcess(object): box_thresh=0.7, max_candidates=1000, unclip_ratio=2.0, + use_dilation=False, **kwargs): self.thresh = thresh self.box_thresh = box_thresh self.max_candidates = max_candidates self.unclip_ratio = unclip_ratio self.min_size = 3 + self.dilation_kernel = None if not use_dilation else [[1, 1], [1, 1]] def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height): ''' @@ -139,8 +141,14 @@ class DBPostProcess(object): boxes_batch = [] for batch_index in range(pred.shape[0]): src_h, src_w, ratio_h, ratio_w = shape_list[batch_index] - boxes, scores = self.boxes_from_bitmap( - pred[batch_index], segmentation[batch_index], src_w, src_h) + if self.dilation_kernel is not None: + mask = cv2.dilate( + np.array(segmentation[batch_index]).astype(np.uint8), + self.dilation_kernel) + else: + mask = segmentation[batch_index] + boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask, + src_w, src_h) boxes_batch.append({'points': boxes}) - return boxes_batch \ No newline at end of file + return boxes_batch diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py index 004322c832195fdbe5cbdaf20f4186ba7e9f8a26..af2de054de3656638fee8d4328765c21b4deaea4 100644 --- a/ppocr/utils/save_load.py +++ b/ppocr/utils/save_load.py @@ -55,8 +55,8 @@ def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False): weight_name = weight_name.replace('binarize', '').replace( 'thresh', '') # for DB if weight_name in pre_state_dict.keys(): - logger.info('Load weight: {}, shape: {}'.format( - weight_name, pre_state_dict[weight_name].shape)) + # logger.info('Load weight: {}, shape: {}'.format( + # weight_name, pre_state_dict[weight_name].shape)) if 'encoder_rnn' in key: # delete axis which is 1 pre_state_dict[weight_name] = pre_state_dict[
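
Note on the new disable_se switch: both configs above set disable_se: True under Backbone, and the det_mobilenet_v3.py hunk threads that flag through the block loop (se = se and not self.disable_se) so SE modules are skipped. A minimal construction sketch, assuming PaddlePaddle and this repository are importable; the argument values simply mirror the YAML above, not a verified training setup:

# Sketch only: build the detection backbone with SE blocks disabled,
# mirroring the Backbone section of ch_det_mv3_db_v2.0.yml.
from ppocr.modeling.backbones.det_mobilenet_v3 import MobileNetV3

backbone = MobileNetV3(
    in_channels=3,        # default input channels
    model_name='large',   # matches model_name: large in the config
    scale=0.5,            # matches scale: 0.5
    disable_se=True)      # new flag added in this diff; SE modules are skipped
print(backbone.out_channels)  # channel list later consumed by DBFPN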
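
The simple_dataset.py hunk also fixes a latent TypeError: data_source_num is already an int (the number of entries in label_file_list), so the old len(data_source_num) call could never work when ratio_list was given as a single scalar. A small illustration of the corrected line, using hypothetical values:

# Before: [float(ratio_list)] * len(data_source_num)  -> TypeError, int has no len()
# After:  [float(ratio_list)] * int(data_source_num)  -> broadcasts the scalar ratio
ratio_list = 1.0
data_source_num = 2          # e.g. two entries in label_file_list
if isinstance(ratio_list, (float, int)):
    ratio_list = [float(ratio_list)] * int(data_source_num)
print(ratio_list)            # [1.0, 1.0]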
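
Finally, the DBPostProcess change adds an optional use_dilation switch that dilates the binarized probability map before boxes are extracted, which can help recover thin or broken text regions. A standalone sketch of that step, assuming only OpenCV and NumPy; note the diff stores the kernel as a plain nested list, while this sketch wraps it in np.array so cv2.dilate accepts it directly:

import cv2
import numpy as np

thresh = 0.3                                   # same default as DBPostProcess
# assumed np.array wrapper around the [[1, 1], [1, 1]] kernel from the diff
dilation_kernel = np.array([[1, 1], [1, 1]], dtype=np.uint8)

pred = np.random.rand(1, 640, 640).astype(np.float32)  # stand-in for the model output
segmentation = pred > thresh

for batch_index in range(pred.shape[0]):
    # mirrors the new branch in __call__: dilate the mask, then hand it to boxes_from_bitmap
    mask = cv2.dilate(
        np.array(segmentation[batch_index]).astype(np.uint8), dilation_kernel)
    print(mask.shape, int(mask.sum()))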