提交 b063c417 编写于 作者: 锦鲤AI幸运's avatar 锦鲤AI幸运 🎯

Merge remote-tracking branch 'origin/dygraph' into dygraph

# Conflicts:
#	PPOCRLabel/libs/resources.py
...@@ -401,6 +401,7 @@ class MainWindow(QMainWindow, WindowMixin): ...@@ -401,6 +401,7 @@ class MainWindow(QMainWindow, WindowMixin):
help = action(getStr('tutorial'), self.showTutorialDialog, None, 'help', getStr('tutorialDetail')) help = action(getStr('tutorial'), self.showTutorialDialog, None, 'help', getStr('tutorialDetail'))
showInfo = action(getStr('info'), self.showInfoDialog, None, 'help', getStr('info')) showInfo = action(getStr('info'), self.showInfoDialog, None, 'help', getStr('info'))
showSteps = action(getStr('steps'), self.showStepsDialog, None, 'help', getStr('steps')) showSteps = action(getStr('steps'), self.showStepsDialog, None, 'help', getStr('steps'))
showKeys = action(getStr('keys'), self.showKeysDialog, None, 'help', getStr('keys'))
zoom = QWidgetAction(self) zoom = QWidgetAction(self)
zoom.setDefaultWidget(self.zoomWidget) zoom.setDefaultWidget(self.zoomWidget)
...@@ -568,7 +569,7 @@ class MainWindow(QMainWindow, WindowMixin): ...@@ -568,7 +569,7 @@ class MainWindow(QMainWindow, WindowMixin):
addActions(self.menus.file, addActions(self.menus.file,
(opendir, open_dataset_dir, None, saveLabel, saveRec, self.autoSaveOption, None, resetAll, deleteImg, quit)) (opendir, open_dataset_dir, None, saveLabel, saveRec, self.autoSaveOption, None, resetAll, deleteImg, quit))
addActions(self.menus.help, (showSteps, showInfo)) addActions(self.menus.help, (showKeys,showSteps, showInfo))
addActions(self.menus.view, ( addActions(self.menus.view, (
self.displayLabelOption, self.labelDialogOption, self.displayLabelOption, self.labelDialogOption,
None, None,
...@@ -763,6 +764,10 @@ class MainWindow(QMainWindow, WindowMixin): ...@@ -763,6 +764,10 @@ class MainWindow(QMainWindow, WindowMixin):
msg = stepsInfo(self.lang) msg = stepsInfo(self.lang)
QMessageBox.information(self, u'Information', msg) QMessageBox.information(self, u'Information', msg)
def showKeysDialog(self):
msg = keysInfo(self.lang)
QMessageBox.information(self, u'Information', msg)
def createShape(self): def createShape(self):
assert self.beginner() assert self.beginner()
self.canvas.setEditing(False) self.canvas.setEditing(False)
......
此差异已折叠。
...@@ -174,6 +174,7 @@ def stepsInfo(lang='en'): ...@@ -174,6 +174,7 @@ def stepsInfo(lang='en'):
"10. 标注结果:关闭应用程序或切换文件路径后,手动保存过的标签将会被存放在所打开图片文件夹下的" \ "10. 标注结果:关闭应用程序或切换文件路径后,手动保存过的标签将会被存放在所打开图片文件夹下的" \
"*Label.txt*中。在菜单栏点击 “PaddleOCR” - 保存识别结果后,会将此类图片的识别训练数据保存在*crop_img*文件夹下," \ "*Label.txt*中。在菜单栏点击 “PaddleOCR” - 保存识别结果后,会将此类图片的识别训练数据保存在*crop_img*文件夹下," \
"识别标签保存在*rec_gt.txt*中。\n" "识别标签保存在*rec_gt.txt*中。\n"
else: else:
msg = "1. Build and launch using the instructions above.\n" \ msg = "1. Build and launch using the instructions above.\n" \
"2. Click 'Open Dir' in Menu/File to select the folder of the picture.\n"\ "2. Click 'Open Dir' in Menu/File to select the folder of the picture.\n"\
...@@ -187,5 +188,57 @@ def stepsInfo(lang='en'): ...@@ -187,5 +188,57 @@ def stepsInfo(lang='en'):
"8. Click 'Save', the image status will switch to '√',then the program automatically jump to the next.\n"\ "8. Click 'Save', the image status will switch to '√',then the program automatically jump to the next.\n"\
"9. Click 'Delete Image' and the image will be deleted to the recycle bin.\n"\ "9. Click 'Delete Image' and the image will be deleted to the recycle bin.\n"\
"10. Labeling result: After closing the application or switching the file path, the manually saved label will be stored in *Label.txt* under the opened picture folder.\n"\ "10. Labeling result: After closing the application or switching the file path, the manually saved label will be stored in *Label.txt* under the opened picture folder.\n"\
" Click PaddleOCR-Save Recognition Results in the menu bar, the recognition training data of such pictures will be saved in the *crop_img* folder, and the recognition label will be saved in *rec_gt.txt*.\n" " Click PaddleOCR-Save Recognition Results in the menu bar, the recognition training data of such pictures will be saved in the *crop_img* folder, and the recognition label will be saved in *rec_gt.txt*.\n"
return msg
def keysInfo(lang='en'):
if lang == 'ch':
msg = "快捷键\t\t\t说明\n" \
"———————————————————————\n"\
"Ctrl + shift + R\t\t对当前图片的所有标记重新识别\n" \
"W\t\t\t新建矩形框\n" \
"Q\t\t\t新建四点框\n" \
"Ctrl + E\t\t编辑所选框标签\n" \
"Ctrl + R\t\t重新识别所选标记\n" \
"Ctrl + C\t\t复制并粘贴选中的标记框\n" \
"Ctrl + 鼠标左键\t\t多选标记框\n" \
"Backspace\t\t删除所选框\n" \
"Ctrl + V\t\t确认本张图片标记\n" \
"Ctrl + Shift + d\t删除本张图片\n" \
"D\t\t\t下一张图片\n" \
"A\t\t\t上一张图片\n" \
"Ctrl++\t\t\t缩小\n" \
"Ctrl--\t\t\t放大\n" \
"↑→↓←\t\t\t移动标记框\n" \
"———————————————————————\n" \
"注:Mac用户Command键替换上述Ctrl键"
else:
msg = "Shortcut Keys\t\tDescription\n" \
"———————————————————————\n" \
"Ctrl + shift + R\t\tRe-recognize all the labels\n" \
"\t\t\tof the current image\n" \
"\n"\
"W\t\t\tCreate a rect box\n" \
"Q\t\t\tCreate a four-points box\n" \
"Ctrl + E\t\tEdit label of the selected box\n" \
"Ctrl + R\t\tRe-recognize the selected box\n" \
"Ctrl + C\t\tCopy and paste the selected\n" \
"\t\t\tbox\n" \
"\n"\
"Ctrl + Left Mouse\tMulti select the label\n" \
"Button\t\t\tbox\n" \
"\n"\
"Backspace\t\tDelete the selected box\n" \
"Ctrl + V\t\tCheck image\n" \
"Ctrl + Shift + d\tDelete image\n" \
"D\t\t\tNext image\n" \
"A\t\t\tPrevious image\n" \
"Ctrl++\t\t\tZoom in\n" \
"Ctrl--\t\t\tZoom out\n" \
"↑→↓←\t\t\tMove selected box" \
"———————————————————————\n" \
"Notice:For Mac users, use the 'Command' key instead of the 'Ctrl' key"
return msg return msg
\ No newline at end of file
...@@ -90,6 +90,7 @@ saveRec=保存识别结果 ...@@ -90,6 +90,7 @@ saveRec=保存识别结果
tempLabel=待识别 tempLabel=待识别
nullLabel=无法识别 nullLabel=无法识别
steps=操作步骤 steps=操作步骤
keys=快捷键
choseModelLg=选择模型语言 choseModelLg=选择模型语言
cancel=取消 cancel=取消
ok=确认 ok=确认
......
...@@ -90,6 +90,7 @@ saveRec=Save Recognition Result ...@@ -90,6 +90,7 @@ saveRec=Save Recognition Result
tempLabel=TEMPORARY tempLabel=TEMPORARY
nullLabel=NULL nullLabel=NULL
steps=Steps steps=Steps
keys=Shortcut Keys
choseModelLg=Choose Model Language choseModelLg=Choose Model Language
cancel=Cancel cancel=Cancel
ok=OK ok=OK
......
Global:
use_gpu: true
epoch_num: 1200
log_smooth_window: 20
print_batch_step: 2
save_model_dir: ./output/ch_db_mv3/
save_epoch_step: 1200
# evaluation is run every 5000 iterations after the 4000th iteration
eval_batch_step: [3000, 2000]
cal_metric_during_train: False
pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/imgs_en/img_10.jpg
save_res_path: ./output/det_db/predicts_db.txt
Architecture:
name: DistillationModel
algorithm: Distillation
Models:
Student:
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
freeze_params: false
return_all_feats: false
model_type: det
algorithm: DB
Backbone:
name: MobileNetV3
scale: 0.5
model_name: large
disable_se: True
Neck:
name: DBFPN
out_channels: 96
Head:
name: DBHead
k: 50
Student2:
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
freeze_params: false
return_all_feats: false
model_type: det
algorithm: DB
Transform:
Backbone:
name: MobileNetV3
scale: 0.5
model_name: large
disable_se: True
Neck:
name: DBFPN
out_channels: 96
Head:
name: DBHead
k: 50
Teacher:
pretrained: ./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy
freeze_params: true
return_all_feats: false
model_type: det
algorithm: DB
Transform:
Backbone:
name: ResNet
layers: 18
Neck:
name: DBFPN
out_channels: 256
Head:
name: DBHead
k: 50
Loss:
name: CombinedLoss
loss_config_list:
- DistillationDilaDBLoss:
weight: 1.0
model_name_pairs:
- ["Student", "Teacher"]
- ["Student2", "Teacher"]
key: maps
balance_loss: true
main_loss_type: DiceLoss
alpha: 5
beta: 10
ohem_ratio: 3
- DistillationDMLLoss:
model_name_pairs:
- ["Student", "Student2"]
maps_name: "thrink_maps"
weight: 1.0
# act: None
model_name_pairs: ["Student", "Student2"]
key: maps
- DistillationDBLoss:
weight: 1.0
model_name_list: ["Student", "Student2"]
# key: maps
# name: DBLoss
balance_loss: true
main_loss_type: DiceLoss
alpha: 5
beta: 10
ohem_ratio: 3
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: Cosine
learning_rate: 0.001
warmup_epoch: 2
regularizer:
name: 'L2'
factor: 0
PostProcess:
name: DistillationDBPostProcess
model_name: ["Student", "Student2", "Teacher"]
# key: maps
thresh: 0.3
box_thresh: 0.6
max_candidates: 1000
unclip_ratio: 1.5
Metric:
name: DistillationMetric
base_metric_name: DetMetric
main_indicator: hmean
key: "Student"
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/icdar2015/text_localization/
label_file_list:
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
ratio_list: [1.0]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- DetLabelEncode: # Class handling label
- IaaAugment:
augmenter_args:
- { 'type': Fliplr, 'args': { 'p': 0.5 } }
- { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
- { 'type': Resize, 'args': { 'size': [0.5, 3] } }
- EastRandomCropData:
size: [960, 960]
max_tries: 50
keep_ratio: true
- MakeBorderMap:
shrink_ratio: 0.4
thresh_min: 0.3
thresh_max: 0.7
- MakeShrinkMap:
shrink_ratio: 0.4
min_text_size: 8
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/icdar2015/text_localization/
label_file_list:
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- DetLabelEncode: # Class handling label
- DetResizeForTest:
# image_shape: [736, 1280]
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 1 # must be 1
num_workers: 2
Global:
use_gpu: true
epoch_num: 1200
log_smooth_window: 20
print_batch_step: 2
save_model_dir: ./output/ch_db_mv3/
save_epoch_step: 1200
# evaluation is run every 5000 iterations after the 4000th iteration
eval_batch_step: [3000, 2000]
cal_metric_during_train: False
pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/imgs_en/img_10.jpg
save_res_path: ./output/det_db/predicts_db.txt
Architecture:
name: DistillationModel
algorithm: Distillation
Models:
Student:
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
freeze_params: false
return_all_feats: false
model_type: det
algorithm: DB
Backbone:
name: MobileNetV3
scale: 0.5
model_name: large
disable_se: True
Neck:
name: DBFPN
out_channels: 96
Head:
name: DBHead
k: 50
Teacher:
pretrained: ./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy
freeze_params: true
return_all_feats: false
model_type: det
algorithm: DB
Transform:
Backbone:
name: ResNet
layers: 18
Neck:
name: DBFPN
out_channels: 256
Head:
name: DBHead
k: 50
Loss:
name: CombinedLoss
loss_config_list:
- DistillationDilaDBLoss:
weight: 1.0
model_name_pairs:
- ["Student", "Teacher"]
key: maps
balance_loss: true
main_loss_type: DiceLoss
alpha: 5
beta: 10
ohem_ratio: 3
- DistillationDBLoss:
weight: 1.0
model_name_list: ["Student", "Teacher"]
# key: maps
name: DBLoss
balance_loss: true
main_loss_type: DiceLoss
alpha: 5
beta: 10
ohem_ratio: 3
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: Cosine
learning_rate: 0.001
warmup_epoch: 2
regularizer:
name: 'L2'
factor: 0
PostProcess:
name: DistillationDBPostProcess
model_name: ["Student", "Student2"]
key: head_out
thresh: 0.3
box_thresh: 0.6
max_candidates: 1000
unclip_ratio: 1.5
Metric:
name: DistillationMetric
base_metric_name: DetMetric
main_indicator: hmean
key: "Student"
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/icdar2015/text_localization/
label_file_list:
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
ratio_list: [1.0]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- DetLabelEncode: # Class handling label
- IaaAugment:
augmenter_args:
- { 'type': Fliplr, 'args': { 'p': 0.5 } }
- { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
- { 'type': Resize, 'args': { 'size': [0.5, 3] } }
- EastRandomCropData:
size: [960, 960]
max_tries: 50
keep_ratio: true
- MakeBorderMap:
shrink_ratio: 0.4
thresh_min: 0.3
thresh_max: 0.7
- MakeShrinkMap:
shrink_ratio: 0.4
min_text_size: 8
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/icdar2015/text_localization/
label_file_list:
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- DetLabelEncode: # Class handling label
- DetResizeForTest:
# image_shape: [736, 1280]
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 1 # must be 1
num_workers: 2
Global:
use_gpu: true
epoch_num: 1200
log_smooth_window: 20
print_batch_step: 2
save_model_dir: ./output/ch_db_mv3/
save_epoch_step: 1200
# evaluation is run every 5000 iterations after the 4000th iteration
eval_batch_step: [3000, 2000]
cal_metric_during_train: False
pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/imgs_en/img_10.jpg
save_res_path: ./output/det_db/predicts_db.txt
Architecture:
name: DistillationModel
algorithm: Distillation
Models:
Student:
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
freeze_params: false
return_all_feats: false
model_type: det
algorithm: DB
Backbone:
name: MobileNetV3
scale: 0.5
model_name: large
disable_se: True
Neck:
name: DBFPN
out_channels: 96
Head:
name: DBHead
k: 50
Student2:
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
freeze_params: false
return_all_feats: false
model_type: det
algorithm: DB
Transform:
Backbone:
name: MobileNetV3
scale: 0.5
model_name: large
disable_se: True
Neck:
name: DBFPN
out_channels: 96
Head:
name: DBHead
k: 50
Loss:
name: CombinedLoss
loss_config_list:
- DistillationDMLLoss:
model_name_pairs:
- ["Student", "Student2"]
maps_name: "thrink_maps"
weight: 1.0
act: "softmax"
model_name_pairs: ["Student", "Student2"]
key: maps
- DistillationDBLoss:
weight: 1.0
model_name_list: ["Student", "Student2"]
# key: maps
name: DBLoss
balance_loss: true
main_loss_type: DiceLoss
alpha: 5
beta: 10
ohem_ratio: 3
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: Cosine
learning_rate: 0.001
warmup_epoch: 2
regularizer:
name: 'L2'
factor: 0
PostProcess:
name: DistillationDBPostProcess
model_name: ["Student", "Student2"]
key: head_out
thresh: 0.3
box_thresh: 0.6
max_candidates: 1000
unclip_ratio: 1.5
Metric:
name: DistillationMetric
base_metric_name: DetMetric
main_indicator: hmean
key: "Student"
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/icdar2015/text_localization/
label_file_list:
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
ratio_list: [1.0]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- DetLabelEncode: # Class handling label
- IaaAugment:
augmenter_args:
- { 'type': Fliplr, 'args': { 'p': 0.5 } }
- { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
- { 'type': Resize, 'args': { 'size': [0.5, 3] } }
- EastRandomCropData:
size: [960, 960]
max_tries: 50
keep_ratio: true
- MakeBorderMap:
shrink_ratio: 0.4
thresh_min: 0.3
thresh_max: 0.7
- MakeShrinkMap:
shrink_ratio: 0.4
min_text_size: 8
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/icdar2015/text_localization/
label_file_list:
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- DetLabelEncode: # Class handling label
- DetResizeForTest:
# image_shape: [736, 1280]
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 1 # must be 1
num_workers: 2
...@@ -13,7 +13,6 @@ SET(TENSORRT_DIR "" CACHE PATH "Compile demo with TensorRT") ...@@ -13,7 +13,6 @@ SET(TENSORRT_DIR "" CACHE PATH "Compile demo with TensorRT")
set(DEMO_NAME "ocr_system") set(DEMO_NAME "ocr_system")
macro(safe_set_static_flag) macro(safe_set_static_flag)
foreach(flag_var foreach(flag_var
CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE
......
...@@ -668,7 +668,7 @@ void DisposeOutPts(OutPt *&pp) { ...@@ -668,7 +668,7 @@ void DisposeOutPts(OutPt *&pp) {
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
inline void InitEdge(TEdge *e, TEdge *eNext, TEdge *ePrev, const IntPoint &Pt) { inline void InitEdge(TEdge *e, TEdge *eNext, TEdge *ePrev, const IntPoint &Pt) {
std::memset(e, 0, sizeof(TEdge)); std::memset(e, int(0), sizeof(TEdge));
e->Next = eNext; e->Next = eNext;
e->Prev = ePrev; e->Prev = ePrev;
e->Curr = Pt; e->Curr = Pt;
...@@ -1895,17 +1895,17 @@ void Clipper::InsertLocalMinimaIntoAEL(const cInt botY) { ...@@ -1895,17 +1895,17 @@ void Clipper::InsertLocalMinimaIntoAEL(const cInt botY) {
TEdge *rb = lm->RightBound; TEdge *rb = lm->RightBound;
OutPt *Op1 = 0; OutPt *Op1 = 0;
if (!lb) { if (!lb || !rb) {
// nb: don't insert LB into either AEL or SEL // nb: don't insert LB into either AEL or SEL
InsertEdgeIntoAEL(rb, 0); InsertEdgeIntoAEL(rb, 0);
SetWindingCount(*rb); SetWindingCount(*rb);
if (IsContributing(*rb)) if (IsContributing(*rb))
Op1 = AddOutPt(rb, rb->Bot); Op1 = AddOutPt(rb, rb->Bot);
} else if (!rb) { //} else if (!rb) {
InsertEdgeIntoAEL(lb, 0); // InsertEdgeIntoAEL(lb, 0);
SetWindingCount(*lb); // SetWindingCount(*lb);
if (IsContributing(*lb)) // if (IsContributing(*lb))
Op1 = AddOutPt(lb, lb->Bot); // Op1 = AddOutPt(lb, lb->Bot);
InsertScanbeam(lb->Top.Y); InsertScanbeam(lb->Top.Y);
} else { } else {
InsertEdgeIntoAEL(lb, 0); InsertEdgeIntoAEL(lb, 0);
...@@ -2547,13 +2547,13 @@ void Clipper::ProcessHorizontal(TEdge *horzEdge) { ...@@ -2547,13 +2547,13 @@ void Clipper::ProcessHorizontal(TEdge *horzEdge) {
if (dir == dLeftToRight) { if (dir == dLeftToRight) {
maxIt = m_Maxima.begin(); maxIt = m_Maxima.begin();
while (maxIt != m_Maxima.end() && *maxIt <= horzEdge->Bot.X) while (maxIt != m_Maxima.end() && *maxIt <= horzEdge->Bot.X)
maxIt++; ++maxIt;
if (maxIt != m_Maxima.end() && *maxIt >= eLastHorz->Top.X) if (maxIt != m_Maxima.end() && *maxIt >= eLastHorz->Top.X)
maxIt = m_Maxima.end(); maxIt = m_Maxima.end();
} else { } else {
maxRit = m_Maxima.rbegin(); maxRit = m_Maxima.rbegin();
while (maxRit != m_Maxima.rend() && *maxRit > horzEdge->Bot.X) while (maxRit != m_Maxima.rend() && *maxRit > horzEdge->Bot.X)
maxRit++; ++maxRit;
if (maxRit != m_Maxima.rend() && *maxRit <= eLastHorz->Top.X) if (maxRit != m_Maxima.rend() && *maxRit <= eLastHorz->Top.X)
maxRit = m_Maxima.rend(); maxRit = m_Maxima.rend();
} }
...@@ -2576,13 +2576,13 @@ void Clipper::ProcessHorizontal(TEdge *horzEdge) { ...@@ -2576,13 +2576,13 @@ void Clipper::ProcessHorizontal(TEdge *horzEdge) {
while (maxIt != m_Maxima.end() && *maxIt < e->Curr.X) { while (maxIt != m_Maxima.end() && *maxIt < e->Curr.X) {
if (horzEdge->OutIdx >= 0 && !IsOpen) if (horzEdge->OutIdx >= 0 && !IsOpen)
AddOutPt(horzEdge, IntPoint(*maxIt, horzEdge->Bot.Y)); AddOutPt(horzEdge, IntPoint(*maxIt, horzEdge->Bot.Y));
maxIt++; ++maxIt;
} }
} else { } else {
while (maxRit != m_Maxima.rend() && *maxRit > e->Curr.X) { while (maxRit != m_Maxima.rend() && *maxRit > e->Curr.X) {
if (horzEdge->OutIdx >= 0 && !IsOpen) if (horzEdge->OutIdx >= 0 && !IsOpen)
AddOutPt(horzEdge, IntPoint(*maxRit, horzEdge->Bot.Y)); AddOutPt(horzEdge, IntPoint(*maxRit, horzEdge->Bot.Y));
maxRit++; ++maxRit;
} }
} }
}; };
......
...@@ -21,10 +21,10 @@ std::vector<std::string> OCRConfig::split(const std::string &str, ...@@ -21,10 +21,10 @@ std::vector<std::string> OCRConfig::split(const std::string &str,
std::vector<std::string> res; std::vector<std::string> res;
if ("" == str) if ("" == str)
return res; return res;
char *strs = new char[str.length() + 1]; char strs[str.length() + 1];
std::strcpy(strs, str.c_str()); std::strcpy(strs, str.c_str());
char *d = new char[delim.length() + 1]; char d[delim.length() + 1];
std::strcpy(d, delim.c_str()); std::strcpy(d, delim.c_str());
char *p = std::strtok(strs, d); char *p = std::strtok(strs, d);
...@@ -61,4 +61,4 @@ void OCRConfig::PrintConfigInfo() { ...@@ -61,4 +61,4 @@ void OCRConfig::PrintConfigInfo() {
std::cout << "=======End of Paddle OCR inference config======" << std::endl; std::cout << "=======End of Paddle OCR inference config======" << std::endl;
} }
} // namespace PaddleOCR } // namespace PaddleOCR
\ No newline at end of file
...@@ -147,12 +147,12 @@ python3 tools/infer/predict_det.py --image_dir="./doc/imgs/00018069.jpg" --det_m ...@@ -147,12 +147,12 @@ python3 tools/infer/predict_det.py --image_dir="./doc/imgs/00018069.jpg" --det_m
如果输入图片的分辨率比较大,而且想使用更大的分辨率预测,可以设置det_limit_side_len 为想要的值,比如1216: 如果输入图片的分辨率比较大,而且想使用更大的分辨率预测,可以设置det_limit_side_len 为想要的值,比如1216:
``` ```
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --det_limit_type=max --det_limit_side_len=1216 python3 tools/infer/predict_det.py --image_dir="./doc/imgs/1.jpg" --det_model_dir="./inference/det_db/" --det_limit_type=max --det_limit_side_len=1216
``` ```
如果想使用CPU进行预测,执行命令如下 如果想使用CPU进行预测,执行命令如下
``` ```
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --use_gpu=False python3 tools/infer/predict_det.py --image_dir="./doc/imgs/1.jpg" --det_model_dir="./inference/det_db/" --use_gpu=False
``` ```
<a name="DB文本检测模型推理"></a> <a name="DB文本检测模型推理"></a>
......
...@@ -154,12 +154,12 @@ Set as `limit_type='min', det_limit_side_len=960`, it means that the shortest si ...@@ -154,12 +154,12 @@ Set as `limit_type='min', det_limit_side_len=960`, it means that the shortest si
If the resolution of the input picture is relatively large and you want to use a larger resolution prediction, you can set det_limit_side_len to the desired value, such as 1216: If the resolution of the input picture is relatively large and you want to use a larger resolution prediction, you can set det_limit_side_len to the desired value, such as 1216:
``` ```
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/22.jpg" --det_model_dir="./inference/det_db/" --det_limit_type=max --det_limit_side_len=1216 python3 tools/infer/predict_det.py --image_dir="./doc/imgs/1.jpg" --det_model_dir="./inference/det_db/" --det_limit_type=max --det_limit_side_len=1216
``` ```
If you want to use the CPU for prediction, execute the command as follows If you want to use the CPU for prediction, execute the command as follows
``` ```
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/22.jpg" --det_model_dir="./inference/det_db/" --use_gpu=False python3 tools/infer/predict_det.py --image_dir="./doc/imgs/1.jpg" --det_model_dir="./inference/det_db/" --use_gpu=False
``` ```
<a name="DB_DETECTION"></a> <a name="DB_DETECTION"></a>
......
...@@ -15,8 +15,6 @@ ...@@ -15,8 +15,6 @@
- 2020.6.8 Add [datasets](./datasets_en.md) and keep updating - 2020.6.8 Add [datasets](./datasets_en.md) and keep updating
- 2020.6.5 Support exporting `attention` model to `inference_model` - 2020.6.5 Support exporting `attention` model to `inference_model`
- 2020.6.5 Support separate prediction and recognition, output result score - 2020.6.5 Support separate prediction and recognition, output result score
- 2020.6.5 Support exporting `attention` model to `inference_model`
- 2020.6.5 Support separate prediction and recognition, output result score
- 2020.5.30 Provide Lightweight Chinese OCR online experience - 2020.5.30 Provide Lightweight Chinese OCR online experience
- 2020.5.30 Model prediction and training support on Windows system - 2020.5.30 Model prediction and training support on Windows system
- 2020.5.30 Open source general Chinese OCR model - 2020.5.30 Open source general Chinese OCR model
......
doc/joinus.PNG

187.8 KB | W: | H:

doc/joinus.PNG

189.2 KB | W: | H:

doc/joinus.PNG
doc/joinus.PNG
doc/joinus.PNG
doc/joinus.PNG
  • 2-up
  • Swipe
  • Onion skin
...@@ -46,6 +46,7 @@ class SimpleDataSet(Dataset): ...@@ -46,6 +46,7 @@ class SimpleDataSet(Dataset):
self.seed = seed self.seed = seed
logger.info("Initialize indexs of datasets:%s" % label_file_list) logger.info("Initialize indexs of datasets:%s" % label_file_list)
self.data_lines = self.get_image_info_list(label_file_list, ratio_list) self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
self.check_data()
self.data_idx_order_list = list(range(len(self.data_lines))) self.data_idx_order_list = list(range(len(self.data_lines)))
if self.mode == "train" and self.do_shuffle: if self.mode == "train" and self.do_shuffle:
self.shuffle_data_random() self.shuffle_data_random()
...@@ -102,16 +103,8 @@ class SimpleDataSet(Dataset): ...@@ -102,16 +103,8 @@ class SimpleDataSet(Dataset):
def __getitem__(self, idx): def __getitem__(self, idx):
file_idx = self.data_idx_order_list[idx] file_idx = self.data_idx_order_list[idx]
data_line = self.data_lines[file_idx] data = self.data_lines[file_idx]
try: try:
data_line = data_line.decode('utf-8')
substr = data_line.strip("\n").strip("\r").split(self.delimiter)
file_name = substr[0]
label = substr[1]
img_path = os.path.join(self.data_dir, file_name)
data = {'img_path': img_path, 'label': label}
if not os.path.exists(img_path):
raise Exception("{} does not exist!".format(img_path))
with open(data['img_path'], 'rb') as f: with open(data['img_path'], 'rb') as f:
img = f.read() img = f.read()
data['image'] = img data['image'] = img
...@@ -120,8 +113,8 @@ class SimpleDataSet(Dataset): ...@@ -120,8 +113,8 @@ class SimpleDataSet(Dataset):
except: except:
error_meg = traceback.format_exc() error_meg = traceback.format_exc()
self.logger.error( self.logger.error(
"When parsing line {}, error happened with msg: {}".format( "When parsing file {} and label {}, error happened with msg: {}".format(
data_line, error_meg)) data['img_path'],data['label'], error_meg))
outs = None outs = None
if outs is None: if outs is None:
# during evaluation, we should fix the idx to get same results for many times of evaluation. # during evaluation, we should fix the idx to get same results for many times of evaluation.
...@@ -132,3 +125,17 @@ class SimpleDataSet(Dataset): ...@@ -132,3 +125,17 @@ class SimpleDataSet(Dataset):
def __len__(self): def __len__(self):
return len(self.data_idx_order_list) return len(self.data_idx_order_list)
def check_data(self):
new_data_lines = []
for data_line in self.data_lines:
data_line = data_line.decode('utf-8')
substr = data_line.strip("\n").strip("\r").split(self.delimiter)
file_name = substr[0]
label = substr[1]
img_path = os.path.join(self.data_dir, file_name)
if os.path.exists(img_path):
new_data_lines.append({'img_path': img_path, 'label': label})
else:
self.logger.info("{} does not exist!".format(img_path))
self.data_lines = new_data_lines
\ No newline at end of file
...@@ -54,6 +54,27 @@ class CELoss(nn.Layer): ...@@ -54,6 +54,27 @@ class CELoss(nn.Layer):
return loss return loss
class KLJSLoss(object):
def __init__(self, mode='kl'):
assert mode in ['kl', 'js', 'KL', 'JS'], "mode can only be one of ['kl', 'js', 'KL', 'JS']"
self.mode = mode
def __call__(self, p1, p2, reduction="mean"):
loss = paddle.multiply(p2, paddle.log( (p2+1e-5)/(p1+1e-5) + 1e-5))
if self.mode.lower() == "js":
loss += paddle.multiply(p1, paddle.log((p1+1e-5)/(p2+1e-5) + 1e-5))
loss *= 0.5
if reduction == "mean":
loss = paddle.mean(loss, axis=[1,2])
elif reduction=="none" or reduction is None:
return loss
else:
loss = paddle.sum(loss, axis=[1,2])
return loss
class DMLLoss(nn.Layer): class DMLLoss(nn.Layer):
""" """
DMLLoss DMLLoss
...@@ -69,17 +90,21 @@ class DMLLoss(nn.Layer): ...@@ -69,17 +90,21 @@ class DMLLoss(nn.Layer):
self.act = nn.Sigmoid() self.act = nn.Sigmoid()
else: else:
self.act = None self.act = None
self.jskl_loss = KLJSLoss(mode="js")
def forward(self, out1, out2): def forward(self, out1, out2):
if self.act is not None: if self.act is not None:
out1 = self.act(out1) out1 = self.act(out1)
out2 = self.act(out2) out2 = self.act(out2)
if len(out1.shape) < 2:
log_out1 = paddle.log(out1) log_out1 = paddle.log(out1)
log_out2 = paddle.log(out2) log_out2 = paddle.log(out2)
loss = (F.kl_div( loss = (F.kl_div(
log_out1, out2, reduction='batchmean') + F.kl_div( log_out1, out2, reduction='batchmean') + F.kl_div(
log_out2, out1, reduction='batchmean')) / 2.0 log_out2, out1, reduction='batchmean')) / 2.0
else:
loss = self.jskl_loss(out1, out2)
return loss return loss
......
...@@ -17,7 +17,7 @@ import paddle.nn as nn ...@@ -17,7 +17,7 @@ import paddle.nn as nn
from .distillation_loss import DistillationCTCLoss from .distillation_loss import DistillationCTCLoss
from .distillation_loss import DistillationDMLLoss from .distillation_loss import DistillationDMLLoss
from .distillation_loss import DistillationDistanceLoss from .distillation_loss import DistillationDistanceLoss, DistillationDBLoss, DistillationDilaDBLoss
class CombinedLoss(nn.Layer): class CombinedLoss(nn.Layer):
...@@ -44,15 +44,16 @@ class CombinedLoss(nn.Layer): ...@@ -44,15 +44,16 @@ class CombinedLoss(nn.Layer):
def forward(self, input, batch, **kargs): def forward(self, input, batch, **kargs):
loss_dict = {} loss_dict = {}
loss_all = 0.
for idx, loss_func in enumerate(self.loss_func): for idx, loss_func in enumerate(self.loss_func):
loss = loss_func(input, batch, **kargs) loss = loss_func(input, batch, **kargs)
if isinstance(loss, paddle.Tensor): if isinstance(loss, paddle.Tensor):
loss = {"loss_{}_{}".format(str(loss), idx): loss} loss = {"loss_{}_{}".format(str(loss), idx): loss}
weight = self.loss_weight[idx] weight = self.loss_weight[idx]
loss = { for key in loss.keys():
"{}_{}".format(key, idx): loss[key] * weight if key == "loss":
for key in loss loss_all += loss[key] * weight
} else:
loss_dict.update(loss) loss_dict["{}_{}".format(key, idx)] = loss[key]
loss_dict["loss"] = paddle.add_n(list(loss_dict.values())) loss_dict["loss"] = loss_all
return loss_dict return loss_dict
...@@ -14,23 +14,76 @@ ...@@ -14,23 +14,76 @@
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
import numpy as np
import cv2
from .rec_ctc_loss import CTCLoss from .rec_ctc_loss import CTCLoss
from .basic_loss import DMLLoss from .basic_loss import DMLLoss
from .basic_loss import DistanceLoss from .basic_loss import DistanceLoss
from .det_db_loss import DBLoss
from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss
def _sum_loss(loss_dict):
if "loss" in loss_dict.keys():
return loss_dict
else:
loss_dict["loss"] = 0.
for k, value in loss_dict.items():
if k == "loss":
continue
else:
loss_dict["loss"] += value
return loss_dict
class DistillationDMLLoss(DMLLoss): class DistillationDMLLoss(DMLLoss):
""" """
""" """
def __init__(self, model_name_pairs=[], act=None, key=None, def __init__(self,
name="loss_dml"): model_name_pairs=[],
act=None,
key=None,
maps_name=None,
name="dml"):
super().__init__(act=act) super().__init__(act=act)
assert isinstance(model_name_pairs, list) assert isinstance(model_name_pairs, list)
self.key = key self.key = key
self.model_name_pairs = model_name_pairs self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
self.name = name self.name = name
self.maps_name = self._check_maps_name(maps_name)
def _check_model_name_pairs(self, model_name_pairs):
if not isinstance(model_name_pairs, list):
return []
elif isinstance(model_name_pairs[0], list) and isinstance(model_name_pairs[0][0], str):
return model_name_pairs
else:
return [model_name_pairs]
def _check_maps_name(self, maps_name):
if maps_name is None:
return None
elif type(maps_name) == str:
return [maps_name]
elif type(maps_name) == list:
return [maps_name]
else:
return None
def _slice_out(self, outs):
new_outs = {}
for k in self.maps_name:
if k == "thrink_maps":
new_outs[k] = outs[:, 0, :, :]
elif k == "threshold_maps":
new_outs[k] = outs[:, 1, :, :]
elif k == "binary_maps":
new_outs[k] = outs[:, 2, :, :]
else:
continue
return new_outs
def forward(self, predicts, batch): def forward(self, predicts, batch):
loss_dict = dict() loss_dict = dict()
...@@ -40,13 +93,30 @@ class DistillationDMLLoss(DMLLoss): ...@@ -40,13 +93,30 @@ class DistillationDMLLoss(DMLLoss):
if self.key is not None: if self.key is not None:
out1 = out1[self.key] out1 = out1[self.key]
out2 = out2[self.key] out2 = out2[self.key]
loss = super().forward(out1, out2)
if isinstance(loss, dict): if self.maps_name is None:
for key in loss: loss = super().forward(out1, out2)
loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1], if isinstance(loss, dict):
idx)] = loss[key] for key in loss:
loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],
idx)] = loss[key]
else:
loss_dict["{}_{}".format(self.name, idx)] = loss
else: else:
loss_dict["{}_{}".format(self.name, idx)] = loss outs1 = self._slice_out(out1)
outs2 = self._slice_out(out2)
for _c, k in enumerate(outs1.keys()):
loss = super().forward(outs1[k], outs2[k])
if isinstance(loss, dict):
for key in loss:
loss_dict["{}_{}_{}_{}_{}".format(key, pair[
0], pair[1], map_name, idx)] = loss[key]
else:
loss_dict["{}_{}_{}".format(self.name, self.maps_name[_c],
idx)] = loss
loss_dict = _sum_loss(loss_dict)
return loss_dict return loss_dict
...@@ -73,6 +143,98 @@ class DistillationCTCLoss(CTCLoss): ...@@ -73,6 +143,98 @@ class DistillationCTCLoss(CTCLoss):
return loss_dict return loss_dict
class DistillationDBLoss(DBLoss):
def __init__(self,
model_name_list=[],
balance_loss=True,
main_loss_type='DiceLoss',
alpha=5,
beta=10,
ohem_ratio=3,
eps=1e-6,
name="db",
**kwargs):
super().__init__()
self.model_name_list = model_name_list
self.name = name
self.key = None
def forward(self, predicts, batch):
loss_dict = {}
for idx, model_name in enumerate(self.model_name_list):
out = predicts[model_name]
if self.key is not None:
out = out[self.key]
loss = super().forward(out, batch)
if isinstance(loss, dict):
for key in loss.keys():
if key == "loss":
continue
name = "{}_{}_{}".format(self.name, model_name, key)
loss_dict[name] = loss[key]
else:
loss_dict["{}_{}".format(self.name, model_name)] = loss
loss_dict = _sum_loss(loss_dict)
return loss_dict
class DistillationDilaDBLoss(DBLoss):
def __init__(self,
model_name_pairs=[],
key=None,
balance_loss=True,
main_loss_type='DiceLoss',
alpha=5,
beta=10,
ohem_ratio=3,
eps=1e-6,
name="dila_dbloss"):
super().__init__()
self.model_name_pairs = model_name_pairs
self.name = name
self.key = key
def forward(self, predicts, batch):
loss_dict = dict()
for idx, pair in enumerate(self.model_name_pairs):
stu_outs = predicts[pair[0]]
tch_outs = predicts[pair[1]]
if self.key is not None:
stu_preds = stu_outs[self.key]
tch_preds = tch_outs[self.key]
stu_shrink_maps = stu_preds[:, 0, :, :]
stu_binary_maps = stu_preds[:, 2, :, :]
# dilation to teacher prediction
dilation_w = np.array([[1, 1], [1, 1]])
th_shrink_maps = tch_preds[:, 0, :, :]
th_shrink_maps = th_shrink_maps.numpy() > 0.3 # thresh = 0.3
dilate_maps = np.zeros_like(th_shrink_maps).astype(np.float32)
for i in range(th_shrink_maps.shape[0]):
dilate_maps[i] = cv2.dilate(
th_shrink_maps[i, :, :].astype(np.uint8), dilation_w)
th_shrink_maps = paddle.to_tensor(dilate_maps)
label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = batch[
1:]
# calculate the shrink map loss
bce_loss = self.alpha * self.bce_loss(
stu_shrink_maps, th_shrink_maps, label_shrink_mask)
loss_binary_maps = self.dice_loss(stu_binary_maps, th_shrink_maps,
label_shrink_mask)
# k = f"{self.name}_{pair[0]}_{pair[1]}"
k = "{}_{}_{}".format(self.name, pair[0], pair[1])
loss_dict[k] = bce_loss + loss_binary_maps
loss_dict = _sum_loss(loss_dict)
return loss_dict
class DistillationDistanceLoss(DistanceLoss): class DistillationDistanceLoss(DistanceLoss):
""" """
""" """
......
...@@ -55,6 +55,7 @@ class DetMetric(object): ...@@ -55,6 +55,7 @@ class DetMetric(object):
result = self.evaluator.evaluate_image(gt_info_list, det_info_list) result = self.evaluator.evaluate_image(gt_info_list, det_info_list)
self.results.append(result) self.results.append(result)
def get_metric(self): def get_metric(self):
""" """
return metrics { return metrics {
......
...@@ -24,8 +24,8 @@ from .cls_metric import ClsMetric ...@@ -24,8 +24,8 @@ from .cls_metric import ClsMetric
class DistillationMetric(object): class DistillationMetric(object):
def __init__(self, def __init__(self,
key=None, key=None,
base_metric_name="RecMetric", base_metric_name=None,
main_indicator='acc', main_indicator=None,
**kwargs): **kwargs):
self.main_indicator = main_indicator self.main_indicator = main_indicator
self.key = key self.key = key
...@@ -42,16 +42,13 @@ class DistillationMetric(object): ...@@ -42,16 +42,13 @@ class DistillationMetric(object):
main_indicator=self.main_indicator, **self.kwargs) main_indicator=self.main_indicator, **self.kwargs)
self.metrics[key].reset() self.metrics[key].reset()
def __call__(self, preds, *args, **kwargs): def __call__(self, preds, batch, **kwargs):
assert isinstance(preds, dict) assert isinstance(preds, dict)
if self.metrics is None: if self.metrics is None:
self._init_metrcis(preds) self._init_metrcis(preds)
output = dict() output = dict()
for key in preds: for key in preds:
metric = self.metrics[key].__call__(preds[key], *args, **kwargs) self.metrics[key].__call__(preds[key], batch, **kwargs)
for sub_key in metric:
output["{}_{}".format(key, sub_key)] = metric[sub_key]
return output
def get_metric(self): def get_metric(self):
""" """
......
...@@ -79,7 +79,10 @@ class BaseModel(nn.Layer): ...@@ -79,7 +79,10 @@ class BaseModel(nn.Layer):
x = self.neck(x) x = self.neck(x)
y["neck_out"] = x y["neck_out"] = x
x = self.head(x, targets=data) x = self.head(x, targets=data)
y["head_out"] = x if isinstance(x, dict):
y.update(x)
else:
y["head_out"] = x
if self.return_all_feats: if self.return_all_feats:
return y return y
else: else:
......
...@@ -21,7 +21,7 @@ from ppocr.modeling.backbones import build_backbone ...@@ -21,7 +21,7 @@ from ppocr.modeling.backbones import build_backbone
from ppocr.modeling.necks import build_neck from ppocr.modeling.necks import build_neck
from ppocr.modeling.heads import build_head from ppocr.modeling.heads import build_head
from .base_model import BaseModel from .base_model import BaseModel
from ppocr.utils.save_load import init_model from ppocr.utils.save_load import init_model, load_pretrained_params
__all__ = ['DistillationModel'] __all__ = ['DistillationModel']
...@@ -46,7 +46,7 @@ class DistillationModel(nn.Layer): ...@@ -46,7 +46,7 @@ class DistillationModel(nn.Layer):
pretrained = model_config.pop("pretrained") pretrained = model_config.pop("pretrained")
model = BaseModel(model_config) model = BaseModel(model_config)
if pretrained is not None: if pretrained is not None:
init_model(model, path=pretrained) model = load_pretrained_params(model, pretrained)
if freeze_params: if freeze_params:
for param in model.parameters(): for param in model.parameters():
param.trainable = False param.trainable = False
......
...@@ -21,7 +21,7 @@ import copy ...@@ -21,7 +21,7 @@ import copy
__all__ = ['build_post_process'] __all__ = ['build_post_process']
from .db_postprocess import DBPostProcess from .db_postprocess import DBPostProcess, DistillationDBPostProcess
from .east_postprocess import EASTPostProcess from .east_postprocess import EASTPostProcess
from .sast_postprocess import SASTPostProcess from .sast_postprocess import SASTPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \
...@@ -34,7 +34,8 @@ def build_post_process(config, global_config=None): ...@@ -34,7 +34,8 @@ def build_post_process(config, global_config=None):
support_dict = [ support_dict = [
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode', 'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess', 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
'DistillationCTCLabelDecode', 'TableLabelDecode' 'DistillationCTCLabelDecode', 'TableLabelDecode',
'DistillationDBPostProcess'
] ]
config = copy.deepcopy(config) config = copy.deepcopy(config)
......
...@@ -187,3 +187,29 @@ class DBPostProcess(object): ...@@ -187,3 +187,29 @@ class DBPostProcess(object):
boxes_batch.append({'points': boxes}) boxes_batch.append({'points': boxes})
return boxes_batch return boxes_batch
class DistillationDBPostProcess(object):
def __init__(self, model_name=["student"],
key=None,
thresh=0.3,
box_thresh=0.6,
max_candidates=1000,
unclip_ratio=1.5,
use_dilation=False,
score_mode="fast",
**kwargs):
self.model_name = model_name
self.key = key
self.post_process = DBPostProcess(thresh=thresh,
box_thresh=box_thresh,
max_candidates=max_candidates,
unclip_ratio=unclip_ratio,
use_dilation=use_dilation,
score_mode=score_mode)
def __call__(self, predicts, shape_list):
results = {}
for k in self.model_name:
results[k] = self.post_process(predicts[k], shape_list=shape_list)
return results
...@@ -116,6 +116,27 @@ def load_dygraph_params(config, model, logger, optimizer): ...@@ -116,6 +116,27 @@ def load_dygraph_params(config, model, logger, optimizer):
logger.info(f"loaded pretrained_model successful from {pm}") logger.info(f"loaded pretrained_model successful from {pm}")
return {} return {}
def load_pretrained_params(model, path):
if path is None:
return False
if not os.path.exists(path) and not os.path.exists(path + ".pdparams"):
print(f"The pretrained_model {path} does not exists!")
return False
path = path if path.endswith('.pdparams') else path + '.pdparams'
params = paddle.load(path)
state_dict = model.state_dict()
new_state_dict = {}
for k1, k2 in zip(state_dict.keys(), params.keys()):
if list(state_dict[k1].shape) == list(params[k2].shape):
new_state_dict[k1] = params[k2]
else:
print(
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
)
model.set_state_dict(new_state_dict)
print(f"load pretrain successful from {path}")
return model
def save_model(model, def save_model(model,
optimizer, optimizer,
......
model_name:ocr_rec
python:python
gpu_list:0|0,1
Global.auto_cast:null
Global.epoch_num:10
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:
Global.use_gpu:
Global.pretrained_model:null
trainer:norm|pact
norm_train:tools/train.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml
quant_train:deploy/slim/quantization/quant.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml
fpgm_train:null
distill_train:null
eval:tools/eval.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml -o
Global.save_inference_dir:./output/
Global.pretrained_model:
norm_export:tools/export_model.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml -o
quant_export:deploy/slim/quantization/export_model.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml -o
fpgm_export:null
distill_export:null
inference:tools/infer/predict_rec.py
--use_gpu:True|False
--enable_mkldnn:True|False
--cpu_threads:1|6
--rec_batch_num:1
--use_tensorrt:True|False
--precision:fp32|fp16|int8
--rec_model_dir:./inference/ch_ppocr_mobile_v2.0_rec_infer/
--image_dir:./inference/rec_inference
--save_log_path:./test/output/
\ No newline at end of file
...@@ -29,19 +29,21 @@ train_model_list=$(func_parser_value "${lines[0]}") ...@@ -29,19 +29,21 @@ train_model_list=$(func_parser_value "${lines[0]}")
trainer_list=$(func_parser_value "${lines[10]}") trainer_list=$(func_parser_value "${lines[10]}")
# MODE be one of ['lite_train_infer' 'whole_infer' 'whole_train_infer'] # MODE be one of ['lite_train_infer' 'whole_infer' 'whole_train_infer']
MODE=$2 MODE=$2
# prepare pretrained weights and dataset # prepare pretrained weights and dataset
wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams if [ ${train_model_list[*]} = "ocr_det" ]; then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams
cd pretrain_models && tar xf det_mv3_db_v2.0_train.tar && cd ../ wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar
cd pretrain_models && tar xf det_mv3_db_v2.0_train.tar && cd ../
fi
if [ ${MODE} = "lite_train_infer" ];then if [ ${MODE} = "lite_train_infer" ];then
# pretrain lite train data # pretrain lite train data
rm -rf ./train_data/icdar2015 rm -rf ./train_data/icdar2015
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015_lite.tar wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015_lite.tar
cd ./train_data/ && tar xf icdar2015_lite.tar wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ic15_data.tar # todo change to bcebos
cd ./train_data/ && tar xf icdar2015_lite.tar && tar xf ic15_data.tar
ln -s ./icdar2015_lite ./icdar2015 ln -s ./icdar2015_lite ./icdar2015
cd ../ cd ../
epoch=10 epoch=10
...@@ -49,13 +51,15 @@ if [ ${MODE} = "lite_train_infer" ];then ...@@ -49,13 +51,15 @@ if [ ${MODE} = "lite_train_infer" ];then
elif [ ${MODE} = "whole_train_infer" ];then elif [ ${MODE} = "whole_train_infer" ];then
rm -rf ./train_data/icdar2015 rm -rf ./train_data/icdar2015
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015.tar wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015.tar
cd ./train_data/ && tar xf icdar2015.tar && cd ../ wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ic15_data.tar
cd ./train_data/ && tar xf icdar2015.tar && tar xf ic15_data.tar && cd ../
epoch=500 epoch=500
eval_batch_step=200 eval_batch_step=200
elif [ ${MODE} = "whole_infer" ];then elif [ ${MODE} = "whole_infer" ];then
rm -rf ./train_data/icdar2015 rm -rf ./train_data/icdar2015
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015_infer.tar wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015_infer.tar
cd ./train_data/ && tar xf icdar2015_infer.tar wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ic15_data.tar
cd ./train_data/ && tar xf icdar2015_infer.tar && tar xf ic15_data.tar
ln -s ./icdar2015_infer ./icdar2015 ln -s ./icdar2015_infer ./icdar2015
cd ../ cd ../
epoch=10 epoch=10
...@@ -88,9 +92,11 @@ for train_model in ${train_model_list[*]}; do ...@@ -88,9 +92,11 @@ for train_model in ${train_model_list[*]}; do
elif [ ${train_model} = "ocr_rec" ];then elif [ ${train_model} = "ocr_rec" ];then
model_name="ocr_rec" model_name="ocr_rec"
yml_file="configs/rec/rec_mv3_none_bilstm_ctc.yml" yml_file="configs/rec/rec_mv3_none_bilstm_ctc.yml"
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_rec_data_200.tar wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar
cd ./inference && tar xf ch_rec_data_200.tar && cd ../ cd ./inference && tar xf rec_inference.tar && cd ../
img_dir="./inference/ch_rec_data_200/" img_dir="./inference/rec_inference/"
data_dir=./inference/rec_inference
data_label_file=[./inference/rec_inference/rec_gt_test.txt]
fi fi
# eval # eval
......
...@@ -27,7 +27,7 @@ from ppocr.data import build_dataloader ...@@ -27,7 +27,7 @@ from ppocr.data import build_dataloader
from ppocr.modeling.architectures import build_model from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric from ppocr.metrics import build_metric
from ppocr.utils.save_load import init_model from ppocr.utils.save_load import init_model, load_pretrained_params
from ppocr.utils.utility import print_dict from ppocr.utils.utility import print_dict
import tools.program as program import tools.program as program
...@@ -55,7 +55,10 @@ def main(): ...@@ -55,7 +55,10 @@ def main():
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
use_srn = config['Architecture']['algorithm'] == "SRN" use_srn = config['Architecture']['algorithm'] == "SRN"
model_type = config['Architecture']['model_type'] if "model_type" in config['Architecture'].keys():
model_type = config['Architecture']['model_type']
else:
model_type = None
best_model_dict = init_model(config, model) best_model_dict = init_model(config, model)
if len(best_model_dict): if len(best_model_dict):
...@@ -68,7 +71,7 @@ def main(): ...@@ -68,7 +71,7 @@ def main():
# start eval # start eval
metric = program.eval(model, valid_dataloader, post_process_class, metric = program.eval(model, valid_dataloader, post_process_class,
eval_class, model_type, use_srn) eval_class, model_type, use_srn)
logger.info('metric eval ***************') logger.info('metric eval ***************')
for k, v in metric.items(): for k, v in metric.items():
logger.info('{}:{}'.format(k, v)) logger.info('{}:{}'.format(k, v))
......
...@@ -112,7 +112,6 @@ class TextClassifier(object): ...@@ -112,7 +112,6 @@ class TextClassifier(object):
if '180' in label and score > self.cls_thresh: if '180' in label and score > self.cls_thresh:
img_list[indices[beg_img_no + rno]] = cv2.rotate( img_list[indices[beg_img_no + rno]] = cv2.rotate(
img_list[indices[beg_img_no + rno]], 1) img_list[indices[beg_img_no + rno]], 1)
elapse = time.time() - starttime
return img_list, cls_res, elapse return img_list, cls_res, elapse
...@@ -146,7 +145,6 @@ def main(args): ...@@ -146,7 +145,6 @@ def main(args):
cls_res[ino])) cls_res[ino]))
logger.info( logger.info(
"The predict time about text angle classify module is as follows: ") "The predict time about text angle classify module is as follows: ")
text_classifier.cls_times.info(average=False)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -64,6 +64,24 @@ class TextRecognizer(object): ...@@ -64,6 +64,24 @@ class TextRecognizer(object):
self.postprocess_op = build_post_process(postprocess_params) self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors, self.config = \ self.predictor, self.input_tensor, self.output_tensors, self.config = \
utility.create_predictor(args, 'rec', logger) utility.create_predictor(args, 'rec', logger)
self.benchmark = args.benchmark
if args.benchmark:
import auto_log
pid = os.getpid()
self.autolog = auto_log.AutoLogger(
model_name="rec",
model_precision=args.precision,
batch_size=args.rec_batch_num,
data_shape="dynamic",
save_path=args.save_log_path,
inference_config=self.config,
pids=pid,
process_name=None,
gpu_ids=0 if args.use_gpu else None,
time_keys=[
'preprocess_time', 'inference_time', 'postprocess_time'
],
warmup=10)
def resize_norm_img(self, img, max_wh_ratio): def resize_norm_img(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape imgC, imgH, imgW = self.rec_image_shape
...@@ -168,6 +186,8 @@ class TextRecognizer(object): ...@@ -168,6 +186,8 @@ class TextRecognizer(object):
rec_res = [['', 0.0]] * img_num rec_res = [['', 0.0]] * img_num
batch_num = self.rec_batch_num batch_num = self.rec_batch_num
st = time.time() st = time.time()
if self.benchmark:
self.autolog.times.start()
for beg_img_no in range(0, img_num, batch_num): for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num) end_img_no = min(img_num, beg_img_no + batch_num)
norm_img_batch = [] norm_img_batch = []
...@@ -196,6 +216,8 @@ class TextRecognizer(object): ...@@ -196,6 +216,8 @@ class TextRecognizer(object):
norm_img_batch.append(norm_img[0]) norm_img_batch.append(norm_img[0])
norm_img_batch = np.concatenate(norm_img_batch) norm_img_batch = np.concatenate(norm_img_batch)
norm_img_batch = norm_img_batch.copy() norm_img_batch = norm_img_batch.copy()
if self.benchmark:
self.autolog.times.stamp()
if self.rec_algorithm == "SRN": if self.rec_algorithm == "SRN":
encoder_word_pos_list = np.concatenate(encoder_word_pos_list) encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
...@@ -222,6 +244,8 @@ class TextRecognizer(object): ...@@ -222,6 +244,8 @@ class TextRecognizer(object):
for output_tensor in self.output_tensors: for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu() output = output_tensor.copy_to_cpu()
outputs.append(output) outputs.append(output)
if self.benchmark:
self.autolog.times.stamp()
preds = {"predict": outputs[2]} preds = {"predict": outputs[2]}
else: else:
self.input_tensor.copy_from_cpu(norm_img_batch) self.input_tensor.copy_from_cpu(norm_img_batch)
...@@ -231,11 +255,14 @@ class TextRecognizer(object): ...@@ -231,11 +255,14 @@ class TextRecognizer(object):
for output_tensor in self.output_tensors: for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu() output = output_tensor.copy_to_cpu()
outputs.append(output) outputs.append(output)
if self.benchmark:
self.autolog.times.stamp()
preds = outputs[0] preds = outputs[0]
rec_result = self.postprocess_op(preds) rec_result = self.postprocess_op(preds)
for rno in range(len(rec_result)): for rno in range(len(rec_result)):
rec_res[indices[beg_img_no + rno]] = rec_result[rno] rec_res[indices[beg_img_no + rno]] = rec_result[rno]
if self.benchmark:
self.autolog.times.end(stamp=True)
return rec_res, time.time() - st return rec_res, time.time() - st
...@@ -251,9 +278,6 @@ def main(args): ...@@ -251,9 +278,6 @@ def main(args):
for i in range(10): for i in range(10):
res = text_recognizer([img]) res = text_recognizer([img])
cpu_mem, gpu_mem, gpu_util = 0, 0, 0
count = 0
for image_file in image_file_list: for image_file in image_file_list:
img, flag = check_and_read_gif(image_file) img, flag = check_and_read_gif(image_file)
if not flag: if not flag:
...@@ -273,6 +297,8 @@ def main(args): ...@@ -273,6 +297,8 @@ def main(args):
for ino in range(len(img_list)): for ino in range(len(img_list)):
logger.info("Predicts of {}:{}".format(valid_image_file_list[ino], logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
rec_res[ino])) rec_res[ino]))
if args.benchmark:
text_recognizer.autolog.report()
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -24,9 +24,6 @@ from paddle import inference ...@@ -24,9 +24,6 @@ from paddle import inference
import time import time
from ppocr.utils.logging import get_logger from ppocr.utils.logging import get_logger
logger = get_logger()
def str2bool(v): def str2bool(v):
return v.lower() in ("true", "t", "1") return v.lower() in ("true", "t", "1")
......
...@@ -186,7 +186,10 @@ def train(config, ...@@ -186,7 +186,10 @@ def train(config,
model.train() model.train()
use_srn = config['Architecture']['algorithm'] == "SRN" use_srn = config['Architecture']['algorithm'] == "SRN"
model_type = config['Architecture']['model_type'] try:
model_type = config['Architecture']['model_type']
except:
model_type = None
if 'start_epoch' in best_model_dict: if 'start_epoch' in best_model_dict:
start_epoch = best_model_dict['start_epoch'] start_epoch = best_model_dict['start_epoch']
......
...@@ -98,7 +98,6 @@ def main(config, device, logger, vdl_writer): ...@@ -98,7 +98,6 @@ def main(config, device, logger, vdl_writer):
eval_class = build_metric(config['Metric']) eval_class = build_metric(config['Metric'])
# load pretrain model # load pretrain model
pre_best_model_dict = load_dygraph_params(config, model, logger, optimizer) pre_best_model_dict = load_dygraph_params(config, model, logger, optimizer)
logger.info('train dataloader has {} iters'.format(len(train_dataloader))) logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
if valid_dataloader is not None: if valid_dataloader is not None:
logger.info('valid dataloader has {} iters'.format( logger.info('valid dataloader has {} iters'.format(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册