提交 5f2f08a0 编写于 作者: L LDOUBLEV

add ppocr_v2 ch_db

上级 a948584c
Global:
use_gpu: true
epoch_num: 1200
log_smooth_window: 20
print_batch_step: 2
save_model_dir: ./output/ch_db_mv3/
save_epoch_step: 1200
# evaluation is run every 2000 iterations after the 3000th iteration
eval_batch_step: [3000, 2000]
# if pretrained_model is saved in static mode, load_static_weights must be set to True
load_static_weights: True
cal_metric_during_train: False
pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
checkpoints: #./output/det_db_0.001_DiceLoss_256_pp_config_2.0b_4gpu/best_accuracy
save_inference_dir:
use_visualdl: False
infer_img: doc/imgs_en/img_10.jpg
save_res_path: ./output/det_db/predicts_db.txt
Architecture:
model_type: det
algorithm: DB
Transform:
Backbone:
name: MobileNetV3
scale: 0.5
model_name: large
disable_se: True
Neck:
name: DBFPN
out_channels: 96
Head:
name: DBHead
k: 50
Loss:
name: DBLoss
balance_loss: true
main_loss_type: DiceLoss
alpha: 5
beta: 10
ohem_ratio: 3
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: Cosine
learning_rate: 0.001
warmup_epoch: 2
regularizer:
name: 'L2'
factor: 0
PostProcess:
name: DBPostProcess
thresh: 0.3
box_thresh: 0.6
max_candidates: 1000
unclip_ratio: 1.5
Metric:
name: DetMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/icdar2015/text_localization/
label_file_list:
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
ratio_list: [1.0]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- DetLabelEncode: # Class handling label
- IaaAugment:
augmenter_args:
- { 'type': Fliplr, 'args': { 'p': 0.5 } }
- { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
- { 'type': Resize, 'args': { 'size': [0.5, 3] } }
- EastRandomCropData:
size: [960, 960]
max_tries: 50
keep_ratio: true
- MakeBorderMap:
shrink_ratio: 0.4
thresh_min: 0.3
thresh_max: 0.7
- MakeShrinkMap:
shrink_ratio: 0.4
min_text_size: 8
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/icdar2015/text_localization/
label_file_list:
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- DetLabelEncode: # Class handling label
- DetResizeForTest:
# image_shape: [736, 1280]
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 1 # must be 1
num_workers: 2
Global:
use_gpu: true
epoch_num: 1200
log_smooth_window: 20
print_batch_step: 2
save_model_dir: ./output/ch_db_res18/
save_epoch_step: 1200
# evaluation is run every 2000 iterations after the 3000th iteration
eval_batch_step: [3000, 2000]
# if pretrained_model is saved in static mode, load_static_weights must be set to True
load_static_weights: True
cal_metric_during_train: False
pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
checkpoints: #./output/det_db_0.001_DiceLoss_256_pp_config_2.0b_4gpu/best_accuracy
save_inference_dir:
use_visualdl: False
infer_img: doc/imgs_en/img_10.jpg
save_res_path: ./output/det_db/predicts_db.txt
Architecture:
model_type: det
algorithm: DB
Transform:
Backbone:
name: ResNet
layers: 18
disable_se: True
Neck:
name: DBFPN
out_channels: 256
Head:
name: DBHead
k: 50
Loss:
name: DBLoss
balance_loss: true
main_loss_type: DiceLoss
alpha: 5
beta: 10
ohem_ratio: 3
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: Cosine
learning_rate: 0.001
warmup_epoch: 2
regularizer:
name: 'L2'
factor: 0
PostProcess:
name: DBPostProcess
thresh: 0.3
box_thresh: 0.6
max_candidates: 1000
unclip_ratio: 1.5
Metric:
name: DetMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/icdar2015/text_localization/
label_file_list:
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
ratio_list: [1.0]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- DetLabelEncode: # Class handling label
- IaaAugment:
augmenter_args:
- { 'type': Fliplr, 'args': { 'p': 0.5 } }
- { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
- { 'type': Resize, 'args': { 'size': [0.5, 3] } }
- EastRandomCropData:
size: [960, 960]
max_tries: 50
keep_ratio: true
- MakeBorderMap:
shrink_ratio: 0.4
thresh_min: 0.3
thresh_max: 0.7
- MakeShrinkMap:
shrink_ratio: 0.4
min_text_size: 8
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/icdar2015/text_localization/
label_file_list:
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- DetLabelEncode: # Class handling label
- DetResizeForTest:
# image_shape: [736, 1280]
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 1 # must be 1
num_workers: 2
...@@ -42,6 +42,8 @@ class DecodeImage(object): ...@@ -42,6 +42,8 @@ class DecodeImage(object):
img) > 0, "invalid input 'img' in DecodeImage" img) > 0, "invalid input 'img' in DecodeImage"
img = np.frombuffer(img, dtype='uint8') img = np.frombuffer(img, dtype='uint8')
img = cv2.imdecode(img, 1) img = cv2.imdecode(img, 1)
if img is None:
return None
if self.img_mode == 'GRAY': if self.img_mode == 'GRAY':
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
elif self.img_mode == 'RGB': elif self.img_mode == 'RGB':
......
...@@ -27,7 +27,10 @@ class SimpleDataSet(Dataset): ...@@ -27,7 +27,10 @@ class SimpleDataSet(Dataset):
global_config = config['Global'] global_config = config['Global']
dataset_config = config[mode]['dataset'] dataset_config = config[mode]['dataset']
loader_config = config[mode]['loader'] loader_config = config[mode]['loader']
batch_size = loader_config['batch_size_per_card'] if 'data_num_per_epoch' in loader_config.keys():
data_num_per_epoch = loader_config['data_num_per_epoch']
else:
data_num_per_epoch = None
self.delimiter = dataset_config.get('delimiter', '\t') self.delimiter = dataset_config.get('delimiter', '\t')
label_file_list = dataset_config.pop('label_file_list') label_file_list = dataset_config.pop('label_file_list')
...@@ -43,21 +46,34 @@ class SimpleDataSet(Dataset): ...@@ -43,21 +46,34 @@ class SimpleDataSet(Dataset):
self.do_shuffle = loader_config['shuffle'] self.do_shuffle = loader_config['shuffle']
logger.info("Initialize indexs of datasets:%s" % label_file_list) logger.info("Initialize indexs of datasets:%s" % label_file_list)
self.data_lines = self.get_image_info_list(label_file_list, ratio_list) self.data_lines = self.get_image_info_list(label_file_list, ratio_list,
data_num_per_epoch)
self.data_idx_order_list = list(range(len(self.data_lines))) self.data_idx_order_list = list(range(len(self.data_lines)))
if mode.lower() == "train": if mode.lower() == "train":
self.shuffle_data_random() self.shuffle_data_random()
self.ops = create_operators(dataset_config['transforms'], global_config) self.ops = create_operators(dataset_config['transforms'], global_config)
def get_image_info_list(self, file_list, ratio_list): def _sample_dataset(self, datas, sample_ratio, data_num_per_epoch=None):
sample_num = round(len(datas) * sample_ratio)
if data_num_per_epoch is not None:
sample_num = data_num_per_epoch * sample_ratio
nums, rem = sample_num // len(datas), sample_num % len(datas)
return list(datas) * nums + random.sample(datas, rem)
def get_image_info_list(self,
file_list,
ratio_list,
data_num_per_epoch=None):
if isinstance(file_list, str): if isinstance(file_list, str):
file_list = [file_list] file_list = [file_list]
data_lines = [] data_lines = []
for idx, file in enumerate(file_list): for idx, file in enumerate(file_list):
with open(file, "rb") as f: with open(file, "rb") as f:
lines = f.readlines() lines = f.readlines()
lines = random.sample(lines, lines = self._sample_dataset(lines, ratio_list[idx],
round(len(lines) * ratio_list[idx])) data_num_per_epoch)
data_lines.extend(lines) data_lines.extend(lines)
return data_lines return data_lines
...@@ -76,6 +92,8 @@ class SimpleDataSet(Dataset): ...@@ -76,6 +92,8 @@ class SimpleDataSet(Dataset):
label = substr[1] label = substr[1]
img_path = os.path.join(self.data_dir, file_name) img_path = os.path.join(self.data_dir, file_name)
data = {'img_path': img_path, 'label': label} data = {'img_path': img_path, 'label': label}
if not os.path.exists(img_path):
raise Exception("{} does not exist!".format(img_path))
with open(data['img_path'], 'rb') as f: with open(data['img_path'], 'rb') as f:
img = f.read() img = f.read()
data['image'] = img data['image'] = img
......
...@@ -34,13 +34,21 @@ def make_divisible(v, divisor=8, min_value=None): ...@@ -34,13 +34,21 @@ def make_divisible(v, divisor=8, min_value=None):
class MobileNetV3(nn.Layer): class MobileNetV3(nn.Layer):
def __init__(self, in_channels=3, model_name='large', scale=0.5, **kwargs): def __init__(self,
in_channels=3,
model_name='large',
scale=0.5,
disable_se=False,
**kwargs):
""" """
the MobilenetV3 backbone network for detection module. the MobilenetV3 backbone network for detection module.
Args: Args:
params(dict): the super parameters for build network params(dict): the super parameters for build network
""" """
super(MobileNetV3, self).__init__() super(MobileNetV3, self).__init__()
self.disable_se = disable_se
if model_name == "large": if model_name == "large":
cfg = [ cfg = [
# k, exp, c, se, nl, s, # k, exp, c, se, nl, s,
...@@ -223,7 +231,7 @@ class ResidualUnit(nn.Layer): ...@@ -223,7 +231,7 @@ class ResidualUnit(nn.Layer):
if_act=True, if_act=True,
act=act, act=act,
name=name + "_depthwise") name=name + "_depthwise")
if self.if_se: if self.if_se and not self.disable_se:
self.mid_se = SEModule(mid_channels, name=name + "_se") self.mid_se = SEModule(mid_channels, name=name + "_se")
self.linear_conv = ConvBNLayer( self.linear_conv = ConvBNLayer(
in_channels=mid_channels, in_channels=mid_channels,
...@@ -238,7 +246,7 @@ class ResidualUnit(nn.Layer): ...@@ -238,7 +246,7 @@ class ResidualUnit(nn.Layer):
def forward(self, inputs): def forward(self, inputs):
x = self.expand_conv(inputs) x = self.expand_conv(inputs)
x = self.bottleneck_conv(x) x = self.bottleneck_conv(x)
if self.if_se: if self.if_se and not self.disable_se:
x = self.mid_se(x) x = self.mid_se(x)
x = self.linear_conv(x) x = self.linear_conv(x)
if self.if_shortcut: if self.if_shortcut:
......
...@@ -39,6 +39,7 @@ class DBPostProcess(object): ...@@ -39,6 +39,7 @@ class DBPostProcess(object):
self.max_candidates = max_candidates self.max_candidates = max_candidates
self.unclip_ratio = unclip_ratio self.unclip_ratio = unclip_ratio
self.min_size = 3 self.min_size = 3
self.dilation_kernel = np.array([[1, 1], [1, 1]])
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height): def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
''' '''
...@@ -139,8 +140,11 @@ class DBPostProcess(object): ...@@ -139,8 +140,11 @@ class DBPostProcess(object):
boxes_batch = [] boxes_batch = []
for batch_index in range(pred.shape[0]): for batch_index in range(pred.shape[0]):
height, width = shape_list[batch_index] height, width = shape_list[batch_index]
boxes, scores = self.boxes_from_bitmap( mask = cv2.dilate(
pred[batch_index], segmentation[batch_index], width, height) np.array(segmentation[batch_index]).astype(np.uint8),
self.dilation_kernel)
boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
width, height)
boxes_batch.append({'points': boxes}) boxes_batch.append({'points': boxes})
return boxes_batch return boxes_batch
...@@ -55,8 +55,8 @@ def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False): ...@@ -55,8 +55,8 @@ def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False):
weight_name = weight_name.replace('binarize', '').replace( weight_name = weight_name.replace('binarize', '').replace(
'thresh', '') # for DB 'thresh', '') # for DB
if weight_name in pre_state_dict.keys(): if weight_name in pre_state_dict.keys():
logger.info('Load weight: {}, shape: {}'.format( # logger.info('Load weight: {}, shape: {}'.format(
weight_name, pre_state_dict[weight_name].shape)) # weight_name, pre_state_dict[weight_name].shape))
if 'encoder_rnn' in key: if 'encoder_rnn' in key:
# delete axis which is 1 # delete axis which is 1
pre_state_dict[weight_name] = pre_state_dict[ pre_state_dict[weight_name] = pre_state_dict[
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册