未验证 提交 4cddec73 编写于 作者: littletomatodonkey's avatar littletomatodonkey 提交者: GitHub

add rotnet code (#6065)

* add rotnet code

* add config

* fix infer for ssl

* rm unused code
上级 3c9200c6
Global:
debug: false
use_gpu: true
epoch_num: 100
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/rec_ppocr_v3_rotnet
save_epoch_step: 3
eval_batch_step: [0, 2000]
cal_metric_during_train: true
pretrained_model: null
checkpoints: null
save_inference_dir: null
use_visualdl: false
infer_img: doc/imgs_words/ch/word_1.jpg
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
max_text_length: 25
infer_mode: false
use_space_char: true
save_res_path: ./output/rec/predicts_chinese_lite_v2.0.txt
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: Cosine
learning_rate: 0.001
regularizer:
name: L2
factor: 1.0e-05
Architecture:
model_type: cls
algorithm: CLS
Transform: null
Backbone:
name: MobileNetV1Enhance
scale: 0.5
last_conv_stride: [1, 2]
last_pool_type: avg
Neck:
Head:
name: ClsHead
class_dim: 4
Loss:
name: ClsLoss
main_indicator: acc
PostProcess:
name: ClsPostProcess
Metric:
name: ClsMetric
main_indicator: acc
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data
label_file_list:
- ./train_data/train_list.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- RecAug:
use_tia: False
- RandAugment:
- SSLRotateResize:
image_shape: [3, 48, 320]
- KeepKeys:
keep_keys: ["image", "label"]
loader:
collate_fn: "SSLRotateCollate"
shuffle: true
batch_size_per_card: 32
drop_last: true
num_workers: 8
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data
label_file_list:
- ./train_data/val_list.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- SSLRotateResize:
image_shape: [3, 48, 320]
- KeepKeys:
keep_keys: ["image", "label"]
loader:
collate_fn: "SSLRotateCollate"
shuffle: false
drop_last: false
batch_size_per_card: 64
num_workers: 8
profiler_options: null
...@@ -35,17 +35,7 @@ from ppocr.metrics import build_metric ...@@ -35,17 +35,7 @@ from ppocr.metrics import build_metric
import tools.program as program import tools.program as program
from paddleslim.dygraph.quant import QAT from paddleslim.dygraph.quant import QAT
from ppocr.data import build_dataloader from ppocr.data import build_dataloader
from tools.export_model import export_single_model
def export_single_model(quanter, model, infer_shape, save_path, logger):
quanter.save_quantized_model(
model,
save_path,
input_spec=[
paddle.static.InputSpec(
shape=[None] + infer_shape, dtype='float32')
])
logger.info('inference QAT model is saved to {}'.format(save_path))
def main(): def main():
...@@ -84,17 +74,54 @@ def main(): ...@@ -84,17 +74,54 @@ def main():
config['Global']) config['Global'])
# build model # build model
# for rec algorithm
if hasattr(post_process_class, 'character'): if hasattr(post_process_class, 'character'):
char_num = len(getattr(post_process_class, 'character')) char_num = len(getattr(post_process_class, 'character'))
if config['Architecture']["algorithm"] in ["Distillation", if config['Architecture']["algorithm"] in ["Distillation",
]: # distillation model ]: # distillation model
for key in config['Architecture']["Models"]: for key in config['Architecture']["Models"]:
config['Architecture']["Models"][key]["Head"][ if config['Architecture']['Models'][key]['Head'][
'out_channels'] = char_num 'name'] == 'MultiHead': # for multi head
if config['PostProcess'][
'name'] == 'DistillationSARLabelDecode':
char_num = char_num - 2
# update SARLoss params
assert list(config['Loss']['loss_config_list'][-1].keys())[
0] == 'DistillationSARLoss'
config['Loss']['loss_config_list'][-1][
'DistillationSARLoss']['ignore_index'] = char_num + 1
out_channels_list = {}
out_channels_list['CTCLabelDecode'] = char_num
out_channels_list['SARLabelDecode'] = char_num + 2
config['Architecture']['Models'][key]['Head'][
'out_channels_list'] = out_channels_list
else:
config['Architecture']["Models"][key]["Head"][
'out_channels'] = char_num
elif config['Architecture']['Head'][
'name'] == 'MultiHead': # for multi head
if config['PostProcess']['name'] == 'SARLabelDecode':
char_num = char_num - 2
# update SARLoss params
assert list(config['Loss']['loss_config_list'][1].keys())[
0] == 'SARLoss'
if config['Loss']['loss_config_list'][1]['SARLoss'] is None:
config['Loss']['loss_config_list'][1]['SARLoss'] = {
'ignore_index': char_num + 1
}
else:
config['Loss']['loss_config_list'][1]['SARLoss'][
'ignore_index'] = char_num + 1
out_channels_list = {}
out_channels_list['CTCLabelDecode'] = char_num
out_channels_list['SARLabelDecode'] = char_num + 2
config['Architecture']['Head'][
'out_channels_list'] = out_channels_list
else: # base rec model else: # base rec model
config['Architecture']["Head"]['out_channels'] = char_num config['Architecture']["Head"]['out_channels'] = char_num
if config['PostProcess']['name'] == 'SARLabelDecode': # for SAR model
config['Loss']['ignore_index'] = char_num - 1
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
# get QAT model # get QAT model
...@@ -120,21 +147,22 @@ def main(): ...@@ -120,21 +147,22 @@ def main():
for k, v in metric.items(): for k, v in metric.items():
logger.info('{}:{}'.format(k, v)) logger.info('{}:{}'.format(k, v))
infer_shape = [3, 32, 100] if model_type == "rec" else [3, 640, 640]
save_path = config["Global"]["save_inference_dir"] save_path = config["Global"]["save_inference_dir"]
arch_config = config["Architecture"] arch_config = config["Architecture"]
arch_config = config["Architecture"]
if arch_config["algorithm"] in ["Distillation", ]: # distillation model if arch_config["algorithm"] in ["Distillation", ]: # distillation model
archs = list(arch_config["Models"].values())
for idx, name in enumerate(model.model_name_list): for idx, name in enumerate(model.model_name_list):
model.model_list[idx].eval() model.model_list[idx].eval()
sub_model_save_path = os.path.join(save_path, name, "inference") sub_model_save_path = os.path.join(save_path, name, "inference")
export_single_model(quanter, model.model_list[idx], infer_shape, export_single_model(model.model_list[idx], archs[idx],
sub_model_save_path, logger) sub_model_save_path, logger, quanter)
else: else:
save_path = os.path.join(save_path, "inference") save_path = os.path.join(save_path, "inference")
model.eval() export_single_model(model, arch_config, save_path, logger, quanter)
export_single_model(quanter, model, infer_shape, save_path, logger)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -112,10 +112,48 @@ def main(config, device, logger, vdl_writer): ...@@ -112,10 +112,48 @@ def main(config, device, logger, vdl_writer):
if config['Architecture']["algorithm"] in ["Distillation", if config['Architecture']["algorithm"] in ["Distillation",
]: # distillation model ]: # distillation model
for key in config['Architecture']["Models"]: for key in config['Architecture']["Models"]:
config['Architecture']["Models"][key]["Head"][ if config['Architecture']['Models'][key]['Head'][
'out_channels'] = char_num 'name'] == 'MultiHead': # for multi head
if config['PostProcess'][
'name'] == 'DistillationSARLabelDecode':
char_num = char_num - 2
# update SARLoss params
assert list(config['Loss']['loss_config_list'][-1].keys())[
0] == 'DistillationSARLoss'
config['Loss']['loss_config_list'][-1][
'DistillationSARLoss']['ignore_index'] = char_num + 1
out_channels_list = {}
out_channels_list['CTCLabelDecode'] = char_num
out_channels_list['SARLabelDecode'] = char_num + 2
config['Architecture']['Models'][key]['Head'][
'out_channels_list'] = out_channels_list
else:
config['Architecture']["Models"][key]["Head"][
'out_channels'] = char_num
elif config['Architecture']['Head'][
'name'] == 'MultiHead': # for multi head
if config['PostProcess']['name'] == 'SARLabelDecode':
char_num = char_num - 2
# update SARLoss params
assert list(config['Loss']['loss_config_list'][1].keys())[
0] == 'SARLoss'
if config['Loss']['loss_config_list'][1]['SARLoss'] is None:
config['Loss']['loss_config_list'][1]['SARLoss'] = {
'ignore_index': char_num + 1
}
else:
config['Loss']['loss_config_list'][1]['SARLoss'][
'ignore_index'] = char_num + 1
out_channels_list = {}
out_channels_list['CTCLabelDecode'] = char_num
out_channels_list['SARLabelDecode'] = char_num + 2
config['Architecture']['Head'][
'out_channels_list'] = out_channels_list
else: # base rec model else: # base rec model
config['Architecture']["Head"]['out_channels'] = char_num config['Architecture']["Head"]['out_channels'] = char_num
if config['PostProcess']['name'] == 'SARLabelDecode': # for SAR model
config['Loss']['ignore_index'] = char_num - 1
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
pre_best_model_dict = dict() pre_best_model_dict = dict()
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#### 2、CDLA数据集 #### 2、CDLA数据集
- **数据来源**:https://github.com/buptlihang/CDLA - **数据来源**:https://github.com/buptlihang/CDLA
- **数据简介**publaynet数据集的训练集合中包含5000张图像,验证集合中包含1000张图像。总共包含10个类别,分别是: `Text, Title, Figure, Figure caption, Table, Table caption, Header, Footer, Reference, Equation`。部分图像以及标注框可视化如下所示。 - **数据简介**CDLA据集的训练集合中包含5000张图像,验证集合中包含1000张图像。总共包含10个类别,分别是: `Text, Title, Figure, Figure caption, Table, Table caption, Header, Footer, Reference, Equation`。部分图像以及标注框可视化如下所示。
<div align="center"> <div align="center">
<img src="../datasets/CDLA_demo/val_0633.jpg" width="500"> <img src="../datasets/CDLA_demo/val_0633.jpg" width="500">
......
...@@ -72,6 +72,7 @@ def build_dataloader(config, mode, device, logger, seed=None): ...@@ -72,6 +72,7 @@ def build_dataloader(config, mode, device, logger, seed=None):
use_shared_memory = loader_config['use_shared_memory'] use_shared_memory = loader_config['use_shared_memory']
else: else:
use_shared_memory = True use_shared_memory = True
if mode == "Train": if mode == "Train":
# Distribute data to multiple cards # Distribute data to multiple cards
batch_sampler = DistributedBatchSampler( batch_sampler = DistributedBatchSampler(
......
...@@ -56,3 +56,17 @@ class ListCollator(object): ...@@ -56,3 +56,17 @@ class ListCollator(object):
for idx in to_tensor_idxs: for idx in to_tensor_idxs:
data_dict[idx] = paddle.to_tensor(data_dict[idx]) data_dict[idx] = paddle.to_tensor(data_dict[idx])
return list(data_dict.values()) return list(data_dict.values())
class SSLRotateCollate(object):
"""
bach: [
[(4*3xH*W), (4,)]
[(4*3xH*W), (4,)]
...
]
"""
def __call__(self, batch):
output = [np.concatenate(d, axis=0) for d in zip(*batch)]
return output
...@@ -24,6 +24,7 @@ from .make_pse_gt import MakePseGt ...@@ -24,6 +24,7 @@ from .make_pse_gt import MakePseGt
from .rec_img_aug import RecAug, RecConAug, RecResizeImg, ClsResizeImg, \ from .rec_img_aug import RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg, SVTRRecResizeImg SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg, SVTRRecResizeImg
from .ssl_img_aug import SSLRotateResize
from .randaugment import RandAugment from .randaugment import RandAugment
from .copy_paste import CopyPaste from .copy_paste import CopyPaste
from .ColorJitter import ColorJitter from .ColorJitter import ColorJitter
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import cv2
import numpy as np
import random
from PIL import Image
from .rec_img_aug import resize_norm_img
class SSLRotateResize(object):
def __init__(self,
image_shape,
padding=False,
select_all=True,
mode="train",
**kwargs):
self.image_shape = image_shape
self.padding = padding
self.select_all = select_all
self.mode = mode
def __call__(self, data):
img = data["image"]
data["image_r90"] = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
data["image_r180"] = cv2.rotate(data["image_r90"],
cv2.ROTATE_90_CLOCKWISE)
data["image_r270"] = cv2.rotate(data["image_r180"],
cv2.ROTATE_90_CLOCKWISE)
images = []
for key in ["image", "image_r90", "image_r180", "image_r270"]:
images.append(
resize_norm_img(
data.pop(key),
image_shape=self.image_shape,
padding=self.padding)[0])
data["image"] = np.stack(images, axis=0)
data["label"] = np.array(list(range(4)))
if not self.select_all:
data["image"] = data["image"][0::2] # just choose 0 and 180
data["label"] = data["label"][0:2] # label needs to be continuous
if self.mode == "test":
data["image"] = data["image"][0]
data["label"] = data["label"][0]
return data
...@@ -17,17 +17,26 @@ import paddle ...@@ -17,17 +17,26 @@ import paddle
class ClsPostProcess(object): class ClsPostProcess(object):
""" Convert between text-label and text-index """ """ Convert between text-label and text-index """
def __init__(self, label_list, **kwargs): def __init__(self, label_list=None, key=None, **kwargs):
super(ClsPostProcess, self).__init__() super(ClsPostProcess, self).__init__()
self.label_list = label_list self.label_list = label_list
self.key = key
def __call__(self, preds, label=None, *args, **kwargs): def __call__(self, preds, label=None, *args, **kwargs):
if self.key is not None:
preds = preds[self.key]
label_list = self.label_list
if label_list is None:
label_list = {idx: idx for idx in range(preds.shape[-1])}
if isinstance(preds, paddle.Tensor): if isinstance(preds, paddle.Tensor):
preds = preds.numpy() preds = preds.numpy()
pred_idxs = preds.argmax(axis=1) pred_idxs = preds.argmax(axis=1)
decode_out = [(self.label_list[idx], preds[i, idx]) decode_out = [(label_list[idx], preds[i, idx])
for i, idx in enumerate(pred_idxs)] for i, idx in enumerate(pred_idxs)]
if label is None: if label is None:
return decode_out return decode_out
label = [(self.label_list[idx], 1.0) for idx in label] label = [(label_list[idx], 1.0) for idx in label]
return decode_out, label return decode_out, label
...@@ -31,7 +31,7 @@ from ppocr.utils.logging import get_logger ...@@ -31,7 +31,7 @@ from ppocr.utils.logging import get_logger
from tools.program import load_config, merge_config, ArgsParser from tools.program import load_config, merge_config, ArgsParser
def export_single_model(model, arch_config, save_path, logger): def export_single_model(model, arch_config, save_path, logger, quanter=None):
if arch_config["algorithm"] == "SRN": if arch_config["algorithm"] == "SRN":
max_text_length = arch_config["Head"]["max_text_length"] max_text_length = arch_config["Head"]["max_text_length"]
other_shape = [ other_shape = [
...@@ -95,7 +95,10 @@ def export_single_model(model, arch_config, save_path, logger): ...@@ -95,7 +95,10 @@ def export_single_model(model, arch_config, save_path, logger):
shape=[None] + infer_shape, dtype="float32") shape=[None] + infer_shape, dtype="float32")
]) ])
paddle.jit.save(model, save_path) if quanter is None:
paddle.jit.save(model, save_path)
else:
quanter.save_quantized_model(model, save_path)
logger.info("inference model is saved to {}".format(save_path)) logger.info("inference model is saved to {}".format(save_path))
return return
...@@ -125,7 +128,6 @@ def main(): ...@@ -125,7 +128,6 @@ def main():
char_num = char_num - 2 char_num = char_num - 2
out_channels_list['CTCLabelDecode'] = char_num out_channels_list['CTCLabelDecode'] = char_num
out_channels_list['SARLabelDecode'] = char_num + 2 out_channels_list['SARLabelDecode'] = char_num + 2
loss_list = config['Loss']['loss_config_list']
config['Architecture']['Models'][key]['Head'][ config['Architecture']['Models'][key]['Head'][
'out_channels_list'] = out_channels_list 'out_channels_list'] = out_channels_list
else: else:
......
...@@ -57,6 +57,8 @@ def main(): ...@@ -57,6 +57,8 @@ def main():
continue continue
elif op_name == 'KeepKeys': elif op_name == 'KeepKeys':
op[op_name]['keep_keys'] = ['image'] op[op_name]['keep_keys'] = ['image']
elif op_name == "SSLRotateResize":
op[op_name]["mode"] = "test"
transforms.append(op) transforms.append(op)
global_config['infer_mode'] = True global_config['infer_mode'] = True
ops = create_operators(transforms, global_config) ops = create_operators(transforms, global_config)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册