diff --git a/configs/det/ch_det_mv3_db.yml b/configs/det/ch_det_mv3_db.yml new file mode 100644 index 0000000000000000000000000000000000000000..275c71b97d21e6b168ef6aafae67eb2eb91c6f2b --- /dev/null +++ b/configs/det/ch_det_mv3_db.yml @@ -0,0 +1,134 @@ +Global: + use_gpu: true + epoch_num: 1200 + log_smooth_window: 20 + print_batch_step: 2 + save_model_dir: ./output/ch_db_mv3/ + save_epoch_step: 1200 + # evaluation is run every 2000 iterations after the 3000th iteration + eval_batch_step: [3000, 2000] + # if pretrained_model is saved in static mode, load_static_weights must be set to True + load_static_weights: True + cal_metric_during_train: False + pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained + checkpoints: #./output/det_db_0.001_DiceLoss_256_pp_config_2.0b_4gpu/best_accuracy + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_en/img_10.jpg + save_res_path: ./output/det_db/predicts_db.txt + +Architecture: + model_type: det + algorithm: DB + Transform: + Backbone: + name: MobileNetV3 + scale: 0.5 + model_name: large + disable_se: True + Neck: + name: DBFPN + out_channels: 96 + Head: + name: DBHead + k: 50 + +Loss: + name: DBLoss + balance_loss: true + main_loss_type: DiceLoss + alpha: 5 + beta: 10 + ohem_ratio: 3 + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + lr: + name: Cosine + learning_rate: 0.001 + warmup_epoch: 2 + regularizer: + name: 'L2' + factor: 0 + +PostProcess: + name: DBPostProcess + thresh: 0.3 + box_thresh: 0.6 + max_candidates: 1000 + unclip_ratio: 1.5 + +Metric: + name: DetMetric + main_indicator: hmean + +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data/icdar2015/text_localization/ + label_file_list: + - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt + ratio_list: [1.0] + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - DetLabelEncode: # Class handling label + - IaaAugment: + augmenter_args: + - { 'type': Fliplr, 'args': { 'p': 
0.5 } } + - { 'type': Affine, 'args': { 'rotate': [-10, 10] } } + - { 'type': Resize, 'args': { 'size': [0.5, 3] } } + - EastRandomCropData: + size: [960, 960] + max_tries: 50 + keep_ratio: true + - MakeBorderMap: + shrink_ratio: 0.4 + thresh_min: 0.3 + thresh_max: 0.7 + - MakeShrinkMap: + shrink_ratio: 0.4 + min_text_size: 8 + - NormalizeImage: + scale: 1./255. + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list + loader: + shuffle: True + drop_last: False + batch_size_per_card: 8 + num_workers: 4 + +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train_data/icdar2015/text_localization/ + label_file_list: + - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - DetLabelEncode: # Class handling label + - DetResizeForTest: +# image_shape: [736, 1280] + - NormalizeImage: + scale: 1./255. 
+ mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: ['image', 'shape', 'polys', 'ignore_tags'] + loader: + shuffle: False + drop_last: False + batch_size_per_card: 1 # must be 1 + num_workers: 2 diff --git a/configs/det/ch_det_res18_db.yml b/configs/det/ch_det_res18_db.yml new file mode 100644 index 0000000000000000000000000000000000000000..9c903fa47a63c013348acdd8b977d31c728ff7d4 --- /dev/null +++ b/configs/det/ch_det_res18_db.yml @@ -0,0 +1,133 @@ +Global: + use_gpu: true + epoch_num: 1200 + log_smooth_window: 20 + print_batch_step: 2 + save_model_dir: ./output/ch_db_res18/ + save_epoch_step: 1200 + # evaluation is run every 2000 iterations after the 3000th iteration + eval_batch_step: [3000, 2000] + # if pretrained_model is saved in static mode, load_static_weights must be set to True + load_static_weights: True + cal_metric_during_train: False + pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained + checkpoints: #./output/det_db_0.001_DiceLoss_256_pp_config_2.0b_4gpu/best_accuracy + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_en/img_10.jpg + save_res_path: ./output/det_db/predicts_db.txt + +Architecture: + model_type: det + algorithm: DB + Transform: + Backbone: + name: ResNet + layers: 18 + disable_se: True + Neck: + name: DBFPN + out_channels: 256 + Head: + name: DBHead + k: 50 + +Loss: + name: DBLoss + balance_loss: true + main_loss_type: DiceLoss + alpha: 5 + beta: 10 + ohem_ratio: 3 + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + lr: + name: Cosine + learning_rate: 0.001 + warmup_epoch: 2 + regularizer: + name: 'L2' + factor: 0 + +PostProcess: + name: DBPostProcess + thresh: 0.3 + box_thresh: 0.6 + max_candidates: 1000 + unclip_ratio: 1.5 + +Metric: + name: DetMetric + main_indicator: hmean + +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data/icdar2015/text_localization/ + label_file_list: + - 
./train_data/icdar2015/text_localization/train_icdar2015_label.txt + ratio_list: [1.0] + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - DetLabelEncode: # Class handling label + - IaaAugment: + augmenter_args: + - { 'type': Fliplr, 'args': { 'p': 0.5 } } + - { 'type': Affine, 'args': { 'rotate': [-10, 10] } } + - { 'type': Resize, 'args': { 'size': [0.5, 3] } } + - EastRandomCropData: + size: [960, 960] + max_tries: 50 + keep_ratio: true + - MakeBorderMap: + shrink_ratio: 0.4 + thresh_min: 0.3 + thresh_max: 0.7 + - MakeShrinkMap: + shrink_ratio: 0.4 + min_text_size: 8 + - NormalizeImage: + scale: 1./255. + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list + loader: + shuffle: True + drop_last: False + batch_size_per_card: 8 + num_workers: 4 + +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train_data/icdar2015/text_localization/ + label_file_list: + - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - DetLabelEncode: # Class handling label + - DetResizeForTest: +# image_shape: [736, 1280] + - NormalizeImage: + scale: 1./255. 
+ mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: ['image', 'shape', 'polys', 'ignore_tags'] + loader: + shuffle: False + drop_last: False + batch_size_per_card: 1 # must be 1 + num_workers: 2 diff --git a/ppocr/data/imaug/operators.py b/ppocr/data/imaug/operators.py index 74b60de4258d7420847a9533f24b4d7da9306e17..927aa6407efceffc2dfb60410f33cc1addf004ac 100644 --- a/ppocr/data/imaug/operators.py +++ b/ppocr/data/imaug/operators.py @@ -42,6 +42,8 @@ class DecodeImage(object): img) > 0, "invalid input 'img' in DecodeImage" img = np.frombuffer(img, dtype='uint8') img = cv2.imdecode(img, 1) + if img is None: + return None if self.img_mode == 'GRAY': img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) elif self.img_mode == 'RGB': diff --git a/ppocr/data/simple_dataset.py b/ppocr/data/simple_dataset.py index 097da768aa4da2eb023c4a346fc30f0e704ab953..168913265b069c74ce23774699ff38f0d29a8589 100644 --- a/ppocr/data/simple_dataset.py +++ b/ppocr/data/simple_dataset.py @@ -27,7 +27,10 @@ class SimpleDataSet(Dataset): global_config = config['Global'] dataset_config = config[mode]['dataset'] loader_config = config[mode]['loader'] - batch_size = loader_config['batch_size_per_card'] + if 'data_num_per_epoch' in loader_config.keys(): + data_num_per_epoch = loader_config['data_num_per_epoch'] + else: + data_num_per_epoch = None self.delimiter = dataset_config.get('delimiter', '\t') label_file_list = dataset_config.pop('label_file_list') @@ -43,21 +46,34 @@ class SimpleDataSet(Dataset): self.do_shuffle = loader_config['shuffle'] logger.info("Initialize indexs of datasets:%s" % label_file_list) - self.data_lines = self.get_image_info_list(label_file_list, ratio_list) + self.data_lines = self.get_image_info_list(label_file_list, ratio_list, + data_num_per_epoch) self.data_idx_order_list = list(range(len(self.data_lines))) if mode.lower() == "train": self.shuffle_data_random() self.ops = 
create_operators(dataset_config['transforms'], global_config) - def get_image_info_list(self, file_list, ratio_list): + def _sample_dataset(self, datas, sample_ratio, data_num_per_epoch=None): + sample_num = round(len(datas) * sample_ratio) + + if data_num_per_epoch is not None: + sample_num = round(data_num_per_epoch * sample_ratio) # must be int + + nums, rem = sample_num // len(datas), sample_num % len(datas) + return list(datas) * nums + random.sample(datas, rem) + + def get_image_info_list(self, + file_list, + ratio_list, + data_num_per_epoch=None): if isinstance(file_list, str): file_list = [file_list] data_lines = [] for idx, file in enumerate(file_list): with open(file, "rb") as f: lines = f.readlines() - lines = random.sample(lines, - round(len(lines) * ratio_list[idx])) + lines = self._sample_dataset(lines, ratio_list[idx], + data_num_per_epoch) data_lines.extend(lines) return data_lines @@ -76,6 +92,8 @@ class SimpleDataSet(Dataset): label = substr[1] img_path = os.path.join(self.data_dir, file_name) data = {'img_path': img_path, 'label': label} + if not os.path.exists(img_path): + raise Exception("{} does not exist!".format(img_path)) with open(data['img_path'], 'rb') as f: img = f.read() data['image'] = img diff --git a/ppocr/modeling/backbones/det_mobilenet_v3.py b/ppocr/modeling/backbones/det_mobilenet_v3.py index 017dce2f4ce624c83529492ea9050703814569d4..d6b453d1fe5441daeaaaadc8b86a7ba0de1c7326 100755 --- a/ppocr/modeling/backbones/det_mobilenet_v3.py +++ b/ppocr/modeling/backbones/det_mobilenet_v3.py @@ -34,13 +34,21 @@ def make_divisible(v, divisor=8, min_value=None): class MobileNetV3(nn.Layer): - def __init__(self, in_channels=3, model_name='large', scale=0.5, **kwargs): + def __init__(self, + in_channels=3, + model_name='large', + scale=0.5, + disable_se=False, + **kwargs): """ the MobilenetV3 backbone network for detection module. 
Args: params(dict): the super parameters for build network """ super(MobileNetV3, self).__init__() + + self.disable_se = disable_se + if model_name == "large": cfg = [ # k, exp, c, se, nl, s, @@ -223,7 +231,7 @@ class ResidualUnit(nn.Layer): if_act=True, act=act, name=name + "_depthwise") - if self.if_se: + if self.if_se and not self.disable_se: self.mid_se = SEModule(mid_channels, name=name + "_se") self.linear_conv = ConvBNLayer( in_channels=mid_channels, @@ -238,7 +246,7 @@ class ResidualUnit(nn.Layer): def forward(self, inputs): x = self.expand_conv(inputs) x = self.bottleneck_conv(x) - if self.if_se: + if self.if_se and not self.disable_se: x = self.mid_se(x) x = self.linear_conv(x) if self.if_shortcut: @@ -273,4 +281,4 @@ class SEModule(nn.Layer): outputs = F.relu(outputs) outputs = self.conv2(outputs) outputs = F.activation.hard_sigmoid(outputs) - return inputs * outputs \ No newline at end of file + return inputs * outputs diff --git a/ppocr/postprocess/db_postprocess.py b/ppocr/postprocess/db_postprocess.py index 316f7fc202c73dbb9a40dbd806f72e4506b991c5..dc27abd6ce6b94b52e9f9287c1c48f2d1c72b145 100644 --- a/ppocr/postprocess/db_postprocess.py +++ b/ppocr/postprocess/db_postprocess.py @@ -39,6 +39,7 @@ class DBPostProcess(object): self.max_candidates = max_candidates self.unclip_ratio = unclip_ratio self.min_size = 3 + self.dilation_kernel = np.array([[1, 1], [1, 1]]) def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height): ''' @@ -139,8 +140,11 @@ class DBPostProcess(object): boxes_batch = [] for batch_index in range(pred.shape[0]): height, width = shape_list[batch_index] - boxes, scores = self.boxes_from_bitmap( - pred[batch_index], segmentation[batch_index], width, height) + mask = cv2.dilate( + np.array(segmentation[batch_index]).astype(np.uint8), + self.dilation_kernel) + boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask, + width, height) boxes_batch.append({'points': boxes}) - return boxes_batch \ No newline at end of file + 
return boxes_batch diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py index 004322c832195fdbe5cbdaf20f4186ba7e9f8a26..af2de054de3656638fee8d4328765c21b4deaea4 100644 --- a/ppocr/utils/save_load.py +++ b/ppocr/utils/save_load.py @@ -55,8 +55,8 @@ def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False): weight_name = weight_name.replace('binarize', '').replace( 'thresh', '') # for DB if weight_name in pre_state_dict.keys(): - logger.info('Load weight: {}, shape: {}'.format( - weight_name, pre_state_dict[weight_name].shape)) + # logger.info('Load weight: {}, shape: {}'.format( + # weight_name, pre_state_dict[weight_name].shape)) if 'encoder_rnn' in key: # delete axis which is 1 pre_state_dict[weight_name] = pre_state_dict[