未验证 提交 56c6c3ae 编写于 作者: X xiaoting 提交者: GitHub

Merge pull request #1 from LDOUBLEV/upload

Upload PaddleOCR code 
- repo: https://github.com/PaddlePaddle/mirrors-yapf.git
sha: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37
hooks:
- id: yapf
files: \.py$
- repo: https://github.com/pre-commit/pre-commit-hooks
sha: a11d9314b22d8f8c7556443875b731ef05965464
hooks:
- id: check-merge-conflict
- id: check-symlinks
- id: detect-private-key
files: (?!.*paddle)^.*$
- id: end-of-file-fixer
files: \.md$
- id: trailing-whitespace
files: \.md$
- repo: https://github.com/Lucas-C/pre-commit-hooks
sha: v1.0.1
hooks:
- id: forbid-crlf
files: \.md$
- id: remove-crlf
files: \.md$
- id: forbid-tabs
files: \.md$
- id: remove-tabs
files: \.md$
- repo: local
hooks:
- id: clang-format
name: clang-format
description: Format files with ClangFormat
entry: bash .clang_format.hook -i
language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$
[style]
based_on_style = pep8
column_limit = 80
TrainReader:
reader_function: ppocr.data.det.dataset_traversal,TrainReader
process_function: ppocr.data.det.db_process,DBProcessTrain
num_workers: 8
img_set_dir: ./train_data/icdar2015/text_localization/
label_file_path: ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
EvalReader:
reader_function: ppocr.data.det.dataset_traversal,EvalTestReader
process_function: ppocr.data.det.db_process,DBProcessTest
img_set_dir: ./train_data/icdar2015/text_localization/
label_file_path: ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
test_image_shape: [736, 1280]
TestReader:
reader_function: ppocr.data.det.dataset_traversal,EvalTestReader
process_function: ppocr.data.det.db_process,DBProcessTest
single_img_path:
img_set_dir: ./train_data/icdar2015/text_localization/
label_file_path: ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
test_image_shape: [736, 1280]
do_eval: True
Global:
algorithm: DB
use_gpu: true
epoch_num: 1200
log_smooth_window: 20
print_batch_step: 2
save_model_dir: output
save_epoch_step: 200
eval_batch_step: 5000
train_batch_size_per_card: 16
test_batch_size_per_card: 16
image_shape: [3, 640, 640]
reader_yml: ./configs/det/det_db_icdar15_reader.yml
pretrain_weights: ./pretrain_models/MobileNetV3_pretrained/MobileNetV3_large_x0_5_pretrained/
save_res_path: ./output/predicts_db.txt
Architecture:
function: ppocr.modeling.architectures.det_model,DetModel
Backbone:
function: ppocr.modeling.backbones.det_mobilenet_v3,MobileNetV3
scale: 0.5
model_name: large
Head:
function: ppocr.modeling.heads.det_db_head,DBHead
model_name: large
k: 50
inner_channels: 96
out_channels: 2
Loss:
function: ppocr.modeling.losses.det_db_loss,DBLoss
balance_loss: true
main_loss_type: DiceLoss
alpha: 5
beta: 10
ohem_ratio: 3
Optimizer:
function: ppocr.optimizer,AdamDecay
base_lr: 0.001
beta1: 0.9
beta2: 0.999
PostProcess:
function: ppocr.postprocess.db_postprocess,DBPostProcess
thresh: 0.3
box_thresh: 0.7
max_candidates: 1000
unclip_ratio: 1.5
\ No newline at end of file
Global:
algorithm: DB
use_gpu: true
epoch_num: 1200
log_smooth_window: 20
print_batch_step: 2
save_model_dir: output
save_epoch_step: 200
eval_batch_step: 5000
train_batch_size_per_card: 8
test_batch_size_per_card: 16
image_shape: [3, 640, 640]
reader_yml: ./configs/det/det_db_icdar15_reader.yml
pretrain_weights: ./pretrain_models/ResNet50_vd_pretrained/
save_res_path: ./output/predicts_db.txt
Architecture:
function: ppocr.modeling.architectures.det_model,DetModel
Backbone:
function: ppocr.modeling.backbones.det_resnet_vd,ResNet
layers: 50
Head:
function: ppocr.modeling.heads.det_db_head,DBHead
model_name: large
k: 50
inner_channels: 256
out_channels: 2
Loss:
function: ppocr.modeling.losses.det_db_loss,DBLoss
balance_loss: true
main_loss_type: DiceLoss
alpha: 5
beta: 10
ohem_ratio: 3
Optimizer:
function: ppocr.optimizer,AdamDecay
base_lr: 0.001
beta1: 0.9
beta2: 0.999
PostProcess:
function: ppocr.postprocess.db_postprocess,DBPostProcess
thresh: 0.3
box_thresh: 0.7
max_candidates: 1000
unclip_ratio: 1.5
\ No newline at end of file
TrainReader:
reader_function: ppocr.data.det.dataset_traversal,TrainReader
process_function: ppocr.data.det.east_process,EASTProcessTrain
num_workers: 8
img_set_dir: ./train_data/icdar2015/text_localization/
label_file_path: ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
background_ratio: 0.125
min_crop_side_ratio: 0.1
min_text_size: 10
EvalReader:
reader_function: ppocr.data.det.dataset_traversal,EvalTestReader
process_function: ppocr.data.det.east_process,EASTProcessTest
img_set_dir: ./train_data/icdar2015/text_localization/
label_file_path: ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
TestReader:
reader_function: ppocr.data.det.dataset_traversal,EvalTestReader
process_function: ppocr.data.det.east_process,EASTProcessTest
single_img_path:
img_set_dir: ./train_data/icdar2015/text_localization/
label_file_path: ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
do_eval: True
Global:
algorithm: EAST
use_gpu: true
epoch_num: 100000
log_smooth_window: 20
print_batch_step: 5
save_model_dir: output
save_epoch_step: 200
eval_batch_step: 5000
train_batch_size_per_card: 16
test_batch_size_per_card: 16
image_shape: [3, 512, 512]
reader_yml: ./configs/det/det_east_icdar15_reader.yml
pretrain_weights: ./pretrain_models/MobileNetV3_pretrained/MobileNetV3_large_x0_5_pretrained/
save_res_path: ./output/predicts_east.txt
Architecture:
function: ppocr.modeling.architectures.det_model,DetModel
Backbone:
function: ppocr.modeling.backbones.det_mobilenet_v3,MobileNetV3
scale: 0.5
model_name: large
Head:
function: ppocr.modeling.heads.det_east_head,EASTHead
model_name: small
Loss:
function: ppocr.modeling.losses.det_east_loss,EASTLoss
Optimizer:
function: ppocr.optimizer,AdamDecay
base_lr: 0.001
beta1: 0.9
beta2: 0.999
PostProcess:
function: ppocr.postprocess.east_postprocess,EASTPostPocess
score_thresh: 0.8
cover_thresh: 0.1
nms_thresh: 0.2
\ No newline at end of file
Global:
algorithm: EAST
use_gpu: true
epoch_num: 100000
log_smooth_window: 20
print_batch_step: 5
save_model_dir: output
save_epoch_step: 200
eval_batch_step: 5000
train_batch_size_per_card: 8
test_batch_size_per_card: 16
image_shape: [3, 512, 512]
reader_yml: ./configs/det/det_east_icdar15_reader.yml
pretrain_weights: ./pretrain_models/ResNet50_vd_pretrained/
save_res_path: ./output/predicts_east.txt
Architecture:
function: ppocr.modeling.architectures.det_model,DetModel
Backbone:
function: ppocr.modeling.backbones.det_resnet_vd,ResNet
layers: 50
Head:
function: ppocr.modeling.heads.det_east_head,EASTHead
model_name: large
Loss:
function: ppocr.modeling.losses.det_east_loss,EASTLoss
Optimizer:
function: ppocr.optimizer,AdamDecay
base_lr: 0.001
beta1: 0.9
beta2: 0.999
PostProcess:
function: ppocr.postprocess.east_postprocess,EASTPostPocess
score_thresh: 0.8
cover_thresh: 0.1
nms_thresh: 0.2
\ No newline at end of file
TrainReader:
reader_function: ppocr.data.rec.dataset_traversal,LMDBReader
num_workers: 8
lmdb_sets_dir: ./train_data/data_lmdb_release/training/
EvalReader:
reader_function: ppocr.data.rec.dataset_traversal,LMDBReader
lmdb_sets_dir: ./train_data/data_lmdb_release/validation/
TestReader:
reader_function: ppocr.data.rec.dataset_traversal,LMDBReader
lmdb_sets_dir: ./train_data/data_lmdb_release/evaluation/
\ No newline at end of file
Global:
algorithm: CRNN
dataset: common
use_gpu: true
epoch_num: 300
log_smooth_window: 20
print_batch_step: 10
save_model_dir: output
save_epoch_step: 3
eval_batch_step: 2000
train_batch_size_per_card: 256
test_batch_size_per_card: 256
image_shape: [3, 32, 100]
max_text_length: 25
character_type: ch
character_dict_path: ./ppocr/utils/ppocr_keys_v1.txt
loss_type: ctc
reader_yml: ./configs/rec/rec_chinese_reader.yml
pretrain_weights:
Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel
Backbone:
function: ppocr.modeling.backbones.rec_mobilenet_v3,MobileNetV3
scale: 0.5
model_name: small
Head:
function: ppocr.modeling.heads.rec_ctc_head,CTCPredict
encoder_type: rnn
SeqRNN:
hidden_size: 48
Loss:
function: ppocr.modeling.losses.rec_ctc_loss,CTCLoss
Optimizer:
function: ppocr.optimizer,AdamDecay
base_lr: 0.001
beta1: 0.9
beta2: 0.999
TrainReader:
reader_function: ppocr.data.rec.dataset_traversal,SimpleReader
num_workers: 8
img_set_dir: .
label_file_path: ./train_data/hard_label.txt
EvalReader:
reader_function: ppocr.data.rec.dataset_traversal,SimpleReader
img_set_dir: .
label_file_path: ./train_data/label_val_all.txt
TestReader:
reader_function: ppocr.data.rec.dataset_traversal,SimpleReader
infer_img: ./infer_img
Global:
algorithm: CRNN
use_gpu: true
epoch_num: 72
log_smooth_window: 20
print_batch_step: 10
save_model_dir: output
save_epoch_step: 3
eval_batch_step: 2000
train_batch_size_per_card: 256
test_batch_size_per_card: 256
image_shape: [3, 32, 100]
max_text_length: 25
character_type: en
loss_type: ctc
reader_yml: ./configs/rec/rec_benchmark_reader.yml
pretrain_weights:
Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel
Backbone:
function: ppocr.modeling.backbones.rec_mobilenet_v3,MobileNetV3
scale: 0.5
model_name: large
Head:
function: ppocr.modeling.heads.rec_ctc_head,CTCPredict
encoder_type: rnn
SeqRNN:
hidden_size: 96
Loss:
function: ppocr.modeling.losses.rec_ctc_loss,CTCLoss
Optimizer:
function: ppocr.optimizer,AdamDecay
base_lr: 0.001
beta1: 0.9
beta2: 0.999
Global:
algorithm: Rosetta
use_gpu: true
epoch_num: 72
log_smooth_window: 20
print_batch_step: 10
save_model_dir: output
save_epoch_step: 3
eval_batch_step: 2000
train_batch_size_per_card: 256
test_batch_size_per_card: 256
image_shape: [3, 32, 100]
max_text_length: 25
character_type: en
loss_type: ctc
reader_yml: ./configs/rec/rec_benchmark_reader.yml
pretrain_weights:
Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel
Backbone:
function: ppocr.modeling.backbones.rec_mobilenet_v3,MobileNetV3
scale: 0.5
model_name: large
Head:
function: ppocr.modeling.heads.rec_ctc_head,CTCPredict
encoder_type: reshape
Loss:
function: ppocr.modeling.losses.rec_ctc_loss,CTCLoss
Optimizer:
function: ppocr.optimizer,AdamDecay
base_lr: 0.001
beta1: 0.9
beta2: 0.999
Global:
algorithm: RARE
use_gpu: true
epoch_num: 72
log_smooth_window: 20
print_batch_step: 10
save_model_dir: output
save_epoch_step: 3
eval_batch_step: 2000
train_batch_size_per_card: 256
test_batch_size_per_card: 256
image_shape: [3, 32, 100]
max_text_length: 25
character_type: en
loss_type: attention
reader_yml: ./configs/rec/rec_benchmark_reader.yml
pretrain_weights:
Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel
TPS:
function: ppocr.modeling.stns.tps,TPS
num_fiducial: 20
loc_lr: 0.1
model_name: small
Backbone:
function: ppocr.modeling.backbones.rec_mobilenet_v3,MobileNetV3
scale: 0.5
model_name: large
Head:
function: ppocr.modeling.heads.rec_attention_head,AttentionPredict
encoder_type: rnn
SeqRNN:
hidden_size: 96
Attention:
decoder_size: 96
word_vector_dim: 96
Loss:
function: ppocr.modeling.losses.rec_attention_loss,AttentionLoss
Optimizer:
function: ppocr.optimizer,AdamDecay
base_lr: 0.001
beta1: 0.9
beta2: 0.999
Global:
algorithm: STARNet
use_gpu: true
epoch_num: 72
log_smooth_window: 20
print_batch_step: 10
save_model_dir: output
save_epoch_step: 3
eval_batch_step: 2000
train_batch_size_per_card: 256
test_batch_size_per_card: 256
image_shape: [3, 32, 100]
max_text_length: 25
character_type: en
loss_type: ctc
reader_yml: ./configs/rec/rec_benchmark_reader.yml
pretrain_weights:
Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel
TPS:
function: ppocr.modeling.stns.tps,TPS
num_fiducial: 20
loc_lr: 0.1
model_name: small
Backbone:
function: ppocr.modeling.backbones.rec_mobilenet_v3,MobileNetV3
scale: 0.5
model_name: large
Head:
function: ppocr.modeling.heads.rec_ctc_head,CTCPredict
encoder_type: rnn
SeqRNN:
hidden_size: 96
Loss:
function: ppocr.modeling.losses.rec_ctc_loss,CTCLoss
Optimizer:
function: ppocr.optimizer,AdamDecay
base_lr: 0.001
beta1: 0.9
beta2: 0.999
Global:
algorithm: CRNN
use_gpu: true
epoch_num: 72
log_smooth_window: 20
print_batch_step: 10
save_model_dir: output
save_epoch_step: 3
eval_batch_step: 2000
train_batch_size_per_card: 256
test_batch_size_per_card: 256
image_shape: [3, 32, 100]
max_text_length: 25
character_type: en
loss_type: ctc
reader_yml: ./configs/rec/rec_benchmark_reader.yml
pretrain_weights:
Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel
Backbone:
function: ppocr.modeling.backbones.rec_resnet_vd,ResNet
layers: 34
Head:
function: ppocr.modeling.heads.rec_ctc_head,CTCPredict
encoder_type: rnn
SeqRNN:
hidden_size: 256
Loss:
function: ppocr.modeling.losses.rec_ctc_loss,CTCLoss
Optimizer:
function: ppocr.optimizer,AdamDecay
base_lr: 0.001
beta1: 0.9
beta2: 0.999
Global:
algorithm: Rosetta
use_gpu: true
epoch_num: 72
log_smooth_window: 20
print_batch_step: 10
save_model_dir: output
save_epoch_step: 3
eval_batch_step: 2000
train_batch_size_per_card: 256
test_batch_size_per_card: 256
image_shape: [3, 32, 100]
max_text_length: 25
character_type: en
loss_type: ctc
reader_yml: ./configs/rec/rec_benchmark_reader.yml
pretrain_weights:
Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel
Backbone:
function: ppocr.modeling.backbones.rec_resnet_vd,ResNet
layers: 34
Head:
function: ppocr.modeling.heads.rec_ctc_head,CTCPredict
encoder_type: reshape
Loss:
function: ppocr.modeling.losses.rec_ctc_loss,CTCLoss
Optimizer:
function: ppocr.optimizer,AdamDecay
base_lr: 0.001
beta1: 0.9
beta2: 0.999
Global:
algorithm: RARE
use_gpu: true
epoch_num: 72
log_smooth_window: 20
print_batch_step: 10
save_model_dir: output
save_epoch_step: 3
eval_batch_step: 2000
train_batch_size_per_card: 256
test_batch_size_per_card: 256
image_shape: [3, 32, 100]
max_text_length: 25
character_type: en
loss_type: attention
reader_yml: ./configs/rec/rec_benchmark_reader.yml
pretrain_weights:
Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel
TPS:
function: ppocr.modeling.stns.tps,TPS
num_fiducial: 20
loc_lr: 0.1
model_name: large
Backbone:
function: ppocr.modeling.backbones.rec_resnet_vd,ResNet
layers: 34
Head:
function: ppocr.modeling.heads.rec_attention_head,AttentionPredict
encoder_type: rnn
SeqRNN:
hidden_size: 256
Attention:
decoder_size: 128
word_vector_dim: 128
Loss:
function: ppocr.modeling.losses.rec_attention_loss,AttentionLoss
Optimizer:
function: ppocr.optimizer,AdamDecay
base_lr: 0.001
beta1: 0.9
beta2: 0.999
Global:
algorithm: STARNet
use_gpu: true
epoch_num: 72
log_smooth_window: 20
print_batch_step: 10
save_model_dir: output
save_epoch_step: 3
eval_batch_step: 2000
train_batch_size_per_card: 256
test_batch_size_per_card: 256
image_shape: [3, 32, 100]
max_text_length: 25
character_type: en
loss_type: ctc
reader_yml: ./configs/rec/rec_benchmark_reader.yml
pretrain_weights:
Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel
TPS:
function: ppocr.modeling.stns.tps,TPS
num_fiducial: 20
loc_lr: 0.1
model_name: large
Backbone:
function: ppocr.modeling.backbones.rec_resnet_vd,ResNet
layers: 34
Head:
function: ppocr.modeling.heads.rec_ctc_head,CTCPredict
encoder_type: rnn
SeqRNN:
hidden_size: 256
Loss:
function: ppocr.modeling.losses.rec_ctc_loss,CTCLoss
Optimizer:
function: ppocr.optimizer,AdamDecay
base_lr: 0.001
beta1: 0.9
beta2: 0.999
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- coding:utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
import random
import cv2
import math
import imgaug
import imgaug.augmenters as iaa
def AugmentData(data):
img = data['image']
shape = img.shape
aug = iaa.Sequential(
[iaa.Fliplr(0.5), iaa.Affine(rotate=(-10, 10)), iaa.Resize(
(0.5, 3))]).to_deterministic()
def may_augment_annotation(aug, data, shape):
if aug is None:
return data
line_polys = []
for poly in data['polys']:
new_poly = may_augment_poly(aug, shape, poly)
line_polys.append(new_poly)
data['polys'] = np.array(line_polys)
return data
def may_augment_poly(aug, img_shape, poly):
keypoints = [imgaug.Keypoint(p[0], p[1]) for p in poly]
keypoints = aug.augment_keypoints(
[imgaug.KeypointsOnImage(
keypoints, shape=img_shape)])[0].keypoints
poly = [(p.x, p.y) for p in keypoints]
return poly
img_aug = aug.augment_image(img)
data['image'] = img_aug
data = may_augment_annotation(aug, data, shape)
return data
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import math
import random
import functools
import numpy as np
import cv2
import string
from ppocr.utils.utility import initial_logger
logger = initial_logger()
from ppocr.utils.utility import create_module
import time
class TrainReader(object):
def __init__(self, params):
self.num_workers = params['num_workers']
self.label_file_path = params['label_file_path']
self.batch_size = params['train_batch_size_per_card']
assert 'process_function' in params,\
"absence process_function in Reader"
self.process = create_module(params['process_function'])(params)
def __call__(self, process_id):
def sample_iter_reader():
with open(self.label_file_path, "rb") as fin:
label_infor_list = fin.readlines()
img_num = len(label_infor_list)
img_id_list = list(range(img_num))
random.shuffle(img_id_list)
for img_id in range(process_id, img_num, self.num_workers):
label_infor = label_infor_list[img_id_list[img_id]]
outs = self.process(label_infor)
if outs is None:
continue
yield outs
def batch_iter_reader():
batch_outs = []
for outs in sample_iter_reader():
batch_outs.append(outs)
if len(batch_outs) == self.batch_size:
yield batch_outs
batch_outs = []
if len(batch_outs) != 0:
yield batch_outs
return batch_iter_reader
class EvalTestReader(object):
def __init__(self, params):
self.params = params
assert 'process_function' in params,\
"absence process_function in EvalTestReader"
def __call__(self, mode):
process_function = create_module(self.params['process_function'])(
self.params)
batch_size = self.params['test_batch_size_per_card']
flag_test_single_img = False
if mode == "test":
single_img_path = self.params['single_img_path']
if single_img_path is not None:
flag_test_single_img = True
img_list = []
if flag_test_single_img:
img_list.append([single_img_path, single_img_path])
else:
img_set_dir = self.params['img_set_dir']
img_name_list_path = self.params['label_file_path']
with open(img_name_list_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
img_name = line.decode().strip("\n").split("\t")[0]
img_path = img_set_dir + "/" + img_name
img_list.append([img_path, img_name])
def batch_iter_reader():
batch_outs = []
for img_path, img_name in img_list:
img = cv2.imread(img_path)
if img is None:
logger.info("load image error:" + img_path)
continue
outs = process_function(img)
outs.append(img_name)
batch_outs.append(outs)
if len(batch_outs) == batch_size:
yield batch_outs
batch_outs = []
if len(batch_outs) != 0:
yield batch_outs
return batch_iter_reader
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import math
import cv2
import numpy as np
import json
import sys
from .data_augment import AugmentData
from .random_crop_data import RandomCropData
from .make_shrink_map import MakeShrinkMap
from .make_border_map import MakeBorderMap
class DBProcessTrain(object):
def __init__(self, params):
self.img_set_dir = params['img_set_dir']
self.image_shape = params['image_shape']
def order_points_clockwise(self, pts):
rect = np.zeros((4, 2), dtype="float32")
s = pts.sum(axis=1)
rect[0] = pts[np.argmin(s)]
rect[2] = pts[np.argmax(s)]
diff = np.diff(pts, axis=1)
rect[1] = pts[np.argmin(diff)]
rect[3] = pts[np.argmax(diff)]
return rect
def make_data_dict(self, imgvalue, entry):
boxes = []
texts = []
ignores = []
for rect in entry:
points = rect['points']
transcription = rect['transcription']
try:
box = self.order_points_clockwise(
np.array(points).reshape(-1, 2))
if cv2.contourArea(box) > 0:
boxes.append(box)
texts.append(transcription)
ignores.append(transcription in ['*', '###'])
except:
print('load label failed!')
data = {
'image': imgvalue,
'shape': [imgvalue.shape[0], imgvalue.shape[1]],
'polys': np.array(boxes),
'texts': texts,
'ignore_tags': ignores,
}
return data
def NormalizeImage(self, data):
im = data['image']
img_mean = [0.485, 0.456, 0.406]
img_std = [0.229, 0.224, 0.225]
im = im.astype(np.float32, copy=False)
im = im / 255
im -= img_mean
im /= img_std
channel_swap = (2, 0, 1)
im = im.transpose(channel_swap)
data['image'] = im
return data
def FilterKeys(self, data):
filter_keys = ['polys', 'texts', 'ignore_tags', 'shape']
for key in filter_keys:
if key in data:
del data[key]
return data
def convert_label_infor(self, label_infor):
label_infor = label_infor.decode()
label_infor = label_infor.encode('utf-8').decode('utf-8-sig')
substr = label_infor.strip("\n").split("\t")
img_path = self.img_set_dir + substr[0]
label = json.loads(substr[1])
return img_path, label
def __call__(self, label_infor):
img_path, gt_label = self.convert_label_infor(label_infor)
imgvalue = cv2.imread(img_path)
if imgvalue is None:
return None
data = self.make_data_dict(imgvalue, gt_label)
data = AugmentData(data)
data = RandomCropData(data, self.image_shape[1:])
data = MakeShrinkMap(data)
data = MakeBorderMap(data)
data = self.NormalizeImage(data)
data = self.FilterKeys(data)
return data['image'], data['shrink_map'], data['shrink_mask'], data[
'threshold_map'], data['threshold_mask']
class DBProcessTest(object):
def __init__(self, params):
super(DBProcessTest, self).__init__()
self.resize_type = 0
if 'det_image_shape' in params:
self.image_shape = params['det_image_shape']
# print(self.image_shape)
self.resize_type = 1
if 'max_side_len' in params:
self.max_side_len = params['max_side_len']
else:
self.max_side_len = 2400
def resize_image_type0(self, im):
"""
resize image to a size multiple of 32 which is required by the network
:param im: the resized image
:param max_side_len: limit of max image size to avoid out of memory in gpu
:return: the resized image and the resize ratio
"""
max_side_len = self.max_side_len
h, w, _ = im.shape
resize_w = w
resize_h = h
# limit the max side
if max(resize_h, resize_w) > max_side_len:
if resize_h > resize_w:
ratio = float(max_side_len) / resize_h
else:
ratio = float(max_side_len) / resize_w
else:
ratio = 1.
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
if resize_h % 32 == 0:
resize_h = resize_h
else:
resize_h = (resize_h // 32 + 1) * 32
if resize_w % 32 == 0:
resize_w = resize_w
else:
resize_w = (resize_w // 32 + 1) * 32
try:
if int(resize_w) <= 0 or int(resize_h) <= 0:
return None, (None, None)
im = cv2.resize(im, (int(resize_w), int(resize_h)))
except:
print(im.shape, resize_w, resize_h)
sys.exit(0)
ratio_h = resize_h / float(h)
ratio_w = resize_w / float(w)
return im, (ratio_h, ratio_w)
def resize_image_type1(self, im):
resize_h, resize_w = self.image_shape
ori_h, ori_w = im.shape[:2] # (h, w, c)
im = cv2.resize(im, (int(resize_w), int(resize_h)))
ratio_h = float(resize_h) / ori_h
ratio_w = float(resize_w) / ori_w
return im, (ratio_h, ratio_w)
def normalize(self, im):
img_mean = [0.485, 0.456, 0.406]
img_std = [0.229, 0.224, 0.225]
im = im.astype(np.float32, copy=False)
im = im / 255
im -= img_mean
im /= img_std
channel_swap = (2, 0, 1)
im = im.transpose(channel_swap)
return im
def __call__(self, im):
if self.resize_type == 0:
im, (ratio_h, ratio_w) = self.resize_image_type0(im)
else:
im, (ratio_h, ratio_w) = self.resize_image_type1(im)
im = self.normalize(im)
im = im[np.newaxis, :]
return [im, (ratio_h, ratio_w)]
此差异已折叠。
# -*- coding:utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
import cv2
np.seterr(divide='ignore', invalid='ignore')
import pyclipper
from shapely.geometry import Polygon
import sys
import warnings
warnings.simplefilter("ignore")
def draw_border_map(polygon, canvas, mask, shrink_ratio):
polygon = np.array(polygon)
assert polygon.ndim == 2
assert polygon.shape[1] == 2
polygon_shape = Polygon(polygon)
if polygon_shape.area <= 0:
return
distance = polygon_shape.area * (
1 - np.power(shrink_ratio, 2)) / polygon_shape.length
subject = [tuple(l) for l in polygon]
padding = pyclipper.PyclipperOffset()
padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
padded_polygon = np.array(padding.Execute(distance)[0])
cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)
xmin = padded_polygon[:, 0].min()
xmax = padded_polygon[:, 0].max()
ymin = padded_polygon[:, 1].min()
ymax = padded_polygon[:, 1].max()
width = xmax - xmin + 1
height = ymax - ymin + 1
polygon[:, 0] = polygon[:, 0] - xmin
polygon[:, 1] = polygon[:, 1] - ymin
xs = np.broadcast_to(
np.linspace(
0, width - 1, num=width).reshape(1, width), (height, width))
ys = np.broadcast_to(
np.linspace(
0, height - 1, num=height).reshape(height, 1), (height, width))
distance_map = np.zeros((polygon.shape[0], height, width), dtype=np.float32)
for i in range(polygon.shape[0]):
j = (i + 1) % polygon.shape[0]
absolute_distance = _distance(xs, ys, polygon[i], polygon[j])
distance_map[i] = np.clip(absolute_distance / distance, 0, 1)
distance_map = distance_map.min(axis=0)
xmin_valid = min(max(0, xmin), canvas.shape[1] - 1)
xmax_valid = min(max(0, xmax), canvas.shape[1] - 1)
ymin_valid = min(max(0, ymin), canvas.shape[0] - 1)
ymax_valid = min(max(0, ymax), canvas.shape[0] - 1)
canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1] = np.fmax(
1 - distance_map[ymin_valid - ymin:ymax_valid - ymax + height,
xmin_valid - xmin:xmax_valid - xmax + width],
canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1])
def _distance(xs, ys, point_1, point_2):
'''
compute the distance from point to a line
ys: coordinates in the first axis
xs: coordinates in the second axis
point_1, point_2: (x, y), the end of the line
'''
height, width = xs.shape[:2]
square_distance_1 = np.square(xs - point_1[0]) + np.square(ys - point_1[1])
square_distance_2 = np.square(xs - point_2[0]) + np.square(ys - point_2[1])
square_distance = np.square(point_1[0] - point_2[0]) + np.square(point_1[
1] - point_2[1])
cosin = (square_distance - square_distance_1 - square_distance_2) / (
2 * np.sqrt(square_distance_1 * square_distance_2))
square_sin = 1 - np.square(cosin)
square_sin = np.nan_to_num(square_sin)
result = np.sqrt(square_distance_1 * square_distance_2 * square_sin /
square_distance)
result[cosin <
0] = np.sqrt(np.fmin(square_distance_1, square_distance_2))[cosin <
0]
# self.extend_line(point_1, point_2, result)
return result
def extend_line(point_1, point_2, result, shrink_ratio):
ex_point_1 = (
int(
round(point_1[0] + (point_1[0] - point_2[0]) * (1 + shrink_ratio))),
int(
round(point_1[1] + (point_1[1] - point_2[1]) * (1 + shrink_ratio))))
cv2.line(
result,
tuple(ex_point_1),
tuple(point_1),
4096.0,
1,
lineType=cv2.LINE_AA,
shift=0)
ex_point_2 = (
int(
round(point_2[0] + (point_2[0] - point_1[0]) * (1 + shrink_ratio))),
int(
round(point_2[1] + (point_2[1] - point_1[1]) * (1 + shrink_ratio))))
cv2.line(
result,
tuple(ex_point_2),
tuple(point_2),
4096.0,
1,
lineType=cv2.LINE_AA,
shift=0)
return ex_point_1, ex_point_2
def MakeBorderMap(data):
shrink_ratio = 0.4
thresh_min = 0.3
thresh_max = 0.7
im = data['image']
text_polys = data['polys']
ignore_tags = data['ignore_tags']
canvas = np.zeros(im.shape[:2], dtype=np.float32)
mask = np.zeros(im.shape[:2], dtype=np.float32)
for i in range(len(text_polys)):
if ignore_tags[i]:
continue
draw_border_map(
text_polys[i], canvas, mask=mask, shrink_ratio=shrink_ratio)
canvas = canvas * (thresh_max - thresh_min) + thresh_min
data['threshold_map'] = canvas
data['threshold_mask'] = mask
return data
# -*- coding:utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
import cv2
from shapely.geometry import Polygon
import pyclipper
def validate_polygons(polygons, ignore_tags, h, w):
'''
polygons (numpy.array, required): of shape (num_instances, num_points, 2)
'''
if len(polygons) == 0:
return polygons, ignore_tags
assert len(polygons) == len(ignore_tags)
for polygon in polygons:
polygon[:, 0] = np.clip(polygon[:, 0], 0, w - 1)
polygon[:, 1] = np.clip(polygon[:, 1], 0, h - 1)
for i in range(len(polygons)):
area = polygon_area(polygons[i])
if abs(area) < 1:
ignore_tags[i] = True
if area > 0:
polygons[i] = polygons[i][::-1, :]
return polygons, ignore_tags
def polygon_area(polygon):
edge = 0
for i in range(polygon.shape[0]):
next_index = (i + 1) % polygon.shape[0]
edge += (polygon[next_index, 0] - polygon[i, 0]) * (
polygon[next_index, 1] - polygon[i, 1])
return edge / 2.
def MakeShrinkMap(data):
min_text_size = 8
shrink_ratio = 0.4
image = data['image']
text_polys = data['polys']
ignore_tags = data['ignore_tags']
h, w = image.shape[:2]
text_polys, ignore_tags = validate_polygons(text_polys, ignore_tags, h, w)
gt = np.zeros((h, w), dtype=np.float32)
# gt = np.zeros((1, h, w), dtype=np.float32)
mask = np.ones((h, w), dtype=np.float32)
for i in range(len(text_polys)):
polygon = text_polys[i]
height = max(polygon[:, 1]) - min(polygon[:, 1])
width = max(polygon[:, 0]) - min(polygon[:, 0])
# height = min(np.linalg.norm(polygon[0] - polygon[3]),
# np.linalg.norm(polygon[1] - polygon[2]))
# width = min(np.linalg.norm(polygon[0] - polygon[1]),
# np.linalg.norm(polygon[2] - polygon[3]))
if ignore_tags[i] or min(height, width) < min_text_size:
cv2.fillPoly(mask, polygon.astype(np.int32)[np.newaxis, :, :], 0)
ignore_tags[i] = True
else:
polygon_shape = Polygon(polygon)
distance = polygon_shape.area * (
1 - np.power(shrink_ratio, 2)) / polygon_shape.length
subject = [tuple(l) for l in text_polys[i]]
padding = pyclipper.PyclipperOffset()
padding.AddPath(subject, pyclipper.JT_ROUND,
pyclipper.ET_CLOSEDPOLYGON)
shrinked = padding.Execute(-distance)
if shrinked == []:
cv2.fillPoly(mask,
polygon.astype(np.int32)[np.newaxis, :, :], 0)
ignore_tags[i] = True
continue
shrinked = np.array(shrinked[0]).reshape(-1, 2)
cv2.fillPoly(gt, [shrinked.astype(np.int32)], 1)
# cv2.fillPoly(gt[0], [shrinked.astype(np.int32)], 1)
data['shrink_map'] = gt
data['shrink_mask'] = mask
return data
# -*- coding:utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
import cv2
import random
def is_poly_in_rect(poly, x, y, w, h):
poly = np.array(poly)
if poly[:, 0].min() < x or poly[:, 0].max() > x + w:
return False
if poly[:, 1].min() < y or poly[:, 1].max() > y + h:
return False
return True
def is_poly_outside_rect(poly, x, y, w, h):
poly = np.array(poly)
if poly[:, 0].max() < x or poly[:, 0].min() > x + w:
return True
if poly[:, 1].max() < y or poly[:, 1].min() > y + h:
return True
return False
def split_regions(axis):
regions = []
min_axis = 0
for i in range(1, axis.shape[0]):
if axis[i] != axis[i - 1] + 1:
region = axis[min_axis:i]
min_axis = i
regions.append(region)
return regions
def random_select(axis, max_size):
xx = np.random.choice(axis, size=2)
xmin = np.min(xx)
xmax = np.max(xx)
xmin = np.clip(xmin, 0, max_size - 1)
xmax = np.clip(xmax, 0, max_size - 1)
return xmin, xmax
def region_wise_random_select(regions, max_size):
selected_index = list(np.random.choice(len(regions), 2))
selected_values = []
for index in selected_index:
axis = regions[index]
xx = int(np.random.choice(axis, size=1))
selected_values.append(xx)
xmin = min(selected_values)
xmax = max(selected_values)
return xmin, xmax
def crop_area(im, text_polys, min_crop_side_ratio, max_tries):
h, w, _ = im.shape
h_array = np.zeros(h, dtype=np.int32)
w_array = np.zeros(w, dtype=np.int32)
for points in text_polys:
points = np.round(points, decimals=0).astype(np.int32)
minx = np.min(points[:, 0])
maxx = np.max(points[:, 0])
w_array[minx:maxx] = 1
miny = np.min(points[:, 1])
maxy = np.max(points[:, 1])
h_array[miny:maxy] = 1
# ensure the cropped area not across a text
h_axis = np.where(h_array == 0)[0]
w_axis = np.where(w_array == 0)[0]
if len(h_axis) == 0 or len(w_axis) == 0:
return 0, 0, w, h
h_regions = split_regions(h_axis)
w_regions = split_regions(w_axis)
for i in range(max_tries):
if len(w_regions) > 1:
xmin, xmax = region_wise_random_select(w_regions, w)
else:
xmin, xmax = random_select(w_axis, w)
if len(h_regions) > 1:
ymin, ymax = region_wise_random_select(h_regions, h)
else:
ymin, ymax = random_select(h_axis, h)
if xmax - xmin < min_crop_side_ratio * w or ymax - ymin < min_crop_side_ratio * h:
# area too small
continue
num_poly_in_rect = 0
for poly in text_polys:
if not is_poly_outside_rect(poly, xmin, ymin, xmax - xmin,
ymax - ymin):
num_poly_in_rect += 1
break
if num_poly_in_rect > 0:
return xmin, ymin, xmax - xmin, ymax - ymin
return 0, 0, w, h
def RandomCropData(data, size):
max_tries = 10
min_crop_side_ratio = 0.1
require_original_image = False
keep_ratio = True
im = data['image']
text_polys = data['polys']
ignore_tags = data['ignore_tags']
texts = data['texts']
all_care_polys = [
text_polys[i] for i, tag in enumerate(ignore_tags) if not tag
]
# 计算crop区域
crop_x, crop_y, crop_w, crop_h = crop_area(im, all_care_polys,
min_crop_side_ratio, max_tries)
# crop 图片 保持比例填充
scale_w = size[0] / crop_w
scale_h = size[1] / crop_h
scale = min(scale_w, scale_h)
h = int(crop_h * scale)
w = int(crop_w * scale)
if keep_ratio:
padimg = np.zeros((size[1], size[0], im.shape[2]), im.dtype)
padimg[:h, :w] = cv2.resize(
im[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h))
img = padimg
else:
img = cv2.resize(im[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w],
tuple(size))
# crop 文本框
text_polys_crop = []
ignore_tags_crop = []
texts_crop = []
for poly, text, tag in zip(text_polys, texts, ignore_tags):
poly = ((poly - (crop_x, crop_y)) * scale).tolist()
if not is_poly_outside_rect(poly, 0, 0, w, h):
text_polys_crop.append(poly)
ignore_tags_crop.append(tag)
texts_crop.append(text)
data['image'] = img
data['polys'] = np.array(text_polys_crop)
data['ignore_tags'] = ignore_tags_crop
data['texts'] = texts_crop
return data
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import random
import numpy as np
import paddle
from ppocr.utils.utility import create_module
from copy import deepcopy
from .rec.img_tools import process_image
import cv2
import sys
import signal
# handle terminate reader process, do not print stack frame
def _reader_quit(signum, frame):
print("Reader process exit.")
sys.exit()
def _term_group(sig_num, frame):
print('pid {} terminated, terminate group '
'{}...'.format(os.getpid(), os.getpgrp()))
os.killpg(os.getpgid(os.getpid()), signal.SIGKILL)
signal.signal(signal.SIGTERM, _reader_quit)
signal.signal(signal.SIGINT, _term_group)
def reader_main(config=None, mode=None):
"""Create a reader for trainning
Args:
settings: arguments
Returns:
train reader
"""
assert mode in ["train", "eval", "test"],\
"Nonsupport mode:{}".format(mode)
global_params = config['Global']
if mode == "train":
params = deepcopy(config['TrainReader'])
elif mode == "eval":
params = deepcopy(config['EvalReader'])
else:
params = deepcopy(config['TestReader'])
params['mode'] = mode
params.update(global_params)
reader_function = params['reader_function']
function = create_module(reader_function)(params)
if mode == "train":
readers = []
num_workers = params['num_workers']
for process_id in range(num_workers):
readers.append(function(process_id))
return paddle.reader.multiprocess_reader(readers, False)
else:
return function(mode)
def test_reader(image_shape, img_path):
img = cv2.imread(img_path)
norm_img = process_image(img, image_shape)
return norm_img
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import os
import math
import random
import numpy as np
import cv2
import string
import lmdb
from ppocr.utils.utility import initial_logger
logger = initial_logger()
from .img_tools import process_image, get_img_data
class LMDBReader(object):
def __init__(self, params):
if params['mode'] != 'train':
self.num_workers = 1
else:
self.num_workers = params['num_workers']
self.lmdb_sets_dir = params['lmdb_sets_dir']
self.char_ops = params['char_ops']
self.image_shape = params['image_shape']
self.loss_type = params['loss_type']
self.max_text_length = params['max_text_length']
self.mode = params['mode']
if params['mode'] == 'train':
self.batch_size = params['train_batch_size_per_card']
else:
self.batch_size = params['test_batch_size_per_card']
def load_hierarchical_lmdb_dataset(self):
lmdb_sets = {}
dataset_idx = 0
for dirpath, dirnames, filenames in os.walk(self.lmdb_sets_dir + '/'):
if not dirnames:
env = lmdb.open(
dirpath,
max_readers=32,
readonly=True,
lock=False,
readahead=False,
meminit=False)
txn = env.begin(write=False)
num_samples = int(txn.get('num-samples'.encode()))
lmdb_sets[dataset_idx] = {"dirpath":dirpath, "env":env, \
"txn":txn, "num_samples":num_samples}
dataset_idx += 1
return lmdb_sets
def print_lmdb_sets_info(self, lmdb_sets):
lmdb_info_strs = []
for dataset_idx in range(len(lmdb_sets)):
tmp_str = " %s:%d," % (lmdb_sets[dataset_idx]['dirpath'],
lmdb_sets[dataset_idx]['num_samples'])
lmdb_info_strs.append(tmp_str)
lmdb_info_strs = ''.join(lmdb_info_strs)
logger.info("DataSummary:" + lmdb_info_strs)
return
def close_lmdb_dataset(self, lmdb_sets):
for dataset_idx in lmdb_sets:
lmdb_sets[dataset_idx]['env'].close()
return
def get_lmdb_sample_info(self, txn, index):
label_key = 'label-%09d'.encode() % index
label = txn.get(label_key)
if label is None:
return None
label = label.decode('utf-8')
img_key = 'image-%09d'.encode() % index
imgbuf = txn.get(img_key)
img = get_img_data(imgbuf)
if img is None:
return None
return img, label
def __call__(self, process_id):
if self.mode != 'train':
process_id = 0
def sample_iter_reader():
lmdb_sets = self.load_hierarchical_lmdb_dataset()
if process_id == 0:
self.print_lmdb_sets_info(lmdb_sets)
cur_index_sets = [1 + process_id] * len(lmdb_sets)
while True:
finish_read_num = 0
for dataset_idx in range(len(lmdb_sets)):
cur_index = cur_index_sets[dataset_idx]
if cur_index > lmdb_sets[dataset_idx]['num_samples']:
finish_read_num += 1
else:
sample_info = self.get_lmdb_sample_info(
lmdb_sets[dataset_idx]['txn'], cur_index)
cur_index_sets[dataset_idx] += self.num_workers
if sample_info is None:
continue
img, label = sample_info
outs = process_image(img, self.image_shape, label,
self.char_ops, self.loss_type,
self.max_text_length)
if outs is None:
continue
yield outs
if finish_read_num == len(lmdb_sets):
break
self.close_lmdb_dataset(lmdb_sets)
def batch_iter_reader():
batch_outs = []
for outs in sample_iter_reader():
batch_outs.append(outs)
if len(batch_outs) == self.batch_size:
yield batch_outs
batch_outs = []
if len(batch_outs) != 0:
yield batch_outs
return batch_iter_reader
class SimpleReader(object):
def __init__(self, params):
if params['mode'] != 'train':
self.num_workers = 1
else:
self.num_workers = params['num_workers']
self.img_set_dir = params['img_set_dir']
self.label_file_path = params['label_file_path']
self.char_ops = params['char_ops']
self.image_shape = params['image_shape']
self.loss_type = params['loss_type']
self.max_text_length = params['max_text_length']
self.mode = params['mode']
if params['mode'] == 'train':
self.batch_size = params['train_batch_size_per_card']
elif params['mode'] == 'eval':
self.batch_size = params['test_batch_size_per_card']
else:
self.batch_size = 1
self.infer_img = params['infer_img']
def __call__(self, process_id):
if self.mode != 'train':
process_id = 0
def sample_iter_reader():
if self.mode == 'test':
print("infer_img:", self.infer_img)
img = cv2.imread(self.infer_img)
norm_img = process_image(img, self.image_shape)
yield norm_img
with open(self.label_file_path, "rb") as fin:
label_infor_list = fin.readlines()
img_num = len(label_infor_list)
img_id_list = list(range(img_num))
random.shuffle(img_id_list)
for img_id in range(process_id, img_num, self.num_workers):
label_infor = label_infor_list[img_id_list[img_id]]
substr = label_infor.decode('utf-8').strip("\n").split("\t")
img_path = self.img_set_dir + "/" + substr[0]
img = cv2.imread(img_path)
if img is None:
continue
label = substr[1]
outs = process_image(img, self.image_shape, label,
self.char_ops, self.loss_type,
self.max_text_length)
if outs is None:
continue
yield outs
def batch_iter_reader():
batch_outs = []
for outs in sample_iter_reader():
batch_outs.append(outs)
if len(batch_outs) == self.batch_size:
yield batch_outs
batch_outs = []
if len(batch_outs) != 0:
yield batch_outs
return batch_iter_reader
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import math
import cv2
import numpy as np
def get_bounding_box_rect(pos):
left = min(pos[0])
right = max(pos[0])
top = min(pos[1])
bottom = max(pos[1])
return [left, top, right, bottom]
def resize_norm_img(img, image_shape):
imgC, imgH, imgW = image_shape
h = img.shape[0]
w = img.shape[1]
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype('float32')
if image_shape[0] == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
else:
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
def get_img_data(value):
"""get_img_data"""
if not value:
return None
imgdata = np.frombuffer(value, dtype='uint8')
if imgdata is None:
return None
imgori = cv2.imdecode(imgdata, 1)
if imgori is None:
return None
return imgori
def process_image(img,
image_shape,
label=None,
char_ops=None,
loss_type=None,
max_text_length=None):
norm_img = resize_norm_img(img, image_shape)
norm_img = norm_img[np.newaxis, :]
if label is not None:
char_num = char_ops.get_char_num()
text = char_ops.encode(label)
if len(text) == 0 or len(text) > max_text_length:
return None
else:
if loss_type == "ctc":
text = text.reshape(-1, 1)
return (norm_img, text)
elif loss_type == "attention":
beg_flag_idx = char_ops.get_beg_end_flag_idx("beg")
end_flag_idx = char_ops.get_beg_end_flag_idx("end")
beg_text = np.append(beg_flag_idx, text)
end_text = np.append(text, end_flag_idx)
beg_text = beg_text.reshape(-1, 1)
end_text = end_text.reshape(-1, 1)
return (norm_img, beg_text, end_text)
else:
assert False, "Unsupport loss_type %s in process_image"\
% loss_type
return (norm_img)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import fluid
from ppocr.utils.utility import create_module
from ppocr.utils.utility import initial_logger
logger = initial_logger()
from copy import deepcopy
class DetModel(object):
def __init__(self, params):
"""
Detection module for OCR text detection.
args:
params (dict): the super parameters for detection module.
"""
global_params = params['Global']
self.algorithm = global_params['algorithm']
backbone_params = deepcopy(params["Backbone"])
backbone_params.update(global_params)
self.backbone = create_module(backbone_params['function'])\
(params=backbone_params)
head_params = deepcopy(params["Head"])
head_params.update(global_params)
self.head = create_module(head_params['function'])\
(params=head_params)
loss_params = deepcopy(params["Loss"])
loss_params.update(global_params)
self.loss = create_module(loss_params['function'])\
(params=loss_params)
self.image_shape = global_params['image_shape']
def create_feed(self, mode):
"""
create Dataloader feeds
args:
mode (str): 'train' for training or else for evaluation
return: (image, corresponding label, dataloader)
"""
image_shape = deepcopy(self.image_shape)
image = fluid.layers.data(
name='image', shape=image_shape, dtype='float32')
if mode == "train":
if self.algorithm == "EAST":
score = fluid.layers.data(
name='score', shape=[1, 128, 128], dtype='float32')
geo = fluid.layers.data(
name='geo', shape=[9, 128, 128], dtype='float32')
mask = fluid.layers.data(
name='mask', shape=[1, 128, 128], dtype='float32')
feed_list = [image, score, geo, mask]
labels = {'score': score, 'geo': geo, 'mask': mask}
elif self.algorithm == "DB":
shrink_map = fluid.layers.data(
name='shrink_map', shape=image_shape[1:], dtype='float32')
shrink_mask = fluid.layers.data(
name='shrink_mask', shape=image_shape[1:], dtype='float32')
threshold_map = fluid.layers.data(
name='threshold_map',
shape=image_shape[1:],
dtype='float32')
threshold_mask = fluid.layers.data(
name='threshold_mask',
shape=image_shape[1:],
dtype='float32')
feed_list=[image, shrink_map, shrink_mask,\
threshold_map, threshold_mask]
labels = {'shrink_map':shrink_map,\
'shrink_mask':shrink_mask,\
'threshold_map':threshold_map,\
'threshold_mask':threshold_mask}
loader = fluid.io.DataLoader.from_generator(
feed_list=feed_list,
capacity=64,
use_double_buffer=True,
iterable=False)
else:
labels = None
loader = None
return image, labels, loader
def __call__(self, mode):
"""
run forward of defined module
args:
mode (str): 'train' for training; 'export' for inference,
others for evaluation]
"""
image, labels, loader = self.create_feed(mode)
conv_feas = self.backbone(image)
predicts = self.head(conv_feas)
if mode == "train":
losses = self.loss(predicts, labels)
return loader, losses
elif mode == "export":
return [image, predicts]
else:
return loader, predicts
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import fluid
from ppocr.utils.utility import create_module
from ppocr.utils.utility import initial_logger
logger = initial_logger()
from copy import deepcopy
class RecModel(object):
def __init__(self, params):
super(RecModel, self).__init__()
global_params = params['Global']
char_num = global_params['char_ops'].get_char_num()
global_params['char_num'] = char_num
if "TPS" in params:
tps_params = deepcopy(params["TPS"])
tps_params.update(global_params)
self.tps = create_module(tps_params['function'])\
(params=tps_params)
else:
self.tps = None
backbone_params = deepcopy(params["Backbone"])
backbone_params.update(global_params)
self.backbone = create_module(backbone_params['function'])\
(params=backbone_params)
head_params = deepcopy(params["Head"])
head_params.update(global_params)
self.head = create_module(head_params['function'])\
(params=head_params)
loss_params = deepcopy(params["Loss"])
loss_params.update(global_params)
self.loss = create_module(loss_params['function'])\
(params=loss_params)
self.loss_type = global_params['loss_type']
self.image_shape = global_params['image_shape']
self.max_text_length = global_params['max_text_length']
def create_feed(self, mode):
image_shape = deepcopy(self.image_shape)
image_shape.insert(0, -1)
image = fluid.data(name='image', shape=image_shape, dtype='float32')
if mode == "train":
if self.loss_type == "attention":
label_in = fluid.data(
name='label_in',
shape=[None, 1],
dtype='int32',
lod_level=1)
label_out = fluid.data(
name='label_out',
shape=[None, 1],
dtype='int32',
lod_level=1)
feed_list = [image, label_in, label_out]
labels = {'label_in': label_in, 'label_out': label_out}
else:
label = fluid.data(
name='label', shape=[None, 1], dtype='int32', lod_level=1)
feed_list = [image, label]
labels = {'label': label}
loader = fluid.io.DataLoader.from_generator(
feed_list=feed_list,
capacity=64,
use_double_buffer=True,
iterable=False)
else:
labels = None
loader = None
return image, labels, loader
def __call__(self, mode):
image, labels, loader = self.create_feed(mode)
if self.tps is None:
inputs = image
else:
inputs = self.tps(image)
conv_feas = self.backbone(inputs)
predicts = self.head(conv_feas, labels, mode)
decoded_out = predicts['decoded_out']
if mode == "train":
loss = self.loss(predicts, labels)
if self.loss_type == "attention":
label = labels['label_out']
else:
label = labels['label']
outputs = {'total_loss':loss, 'decoded_out':\
decoded_out, 'label':label}
return loader, outputs
elif mode == "export":
return [image, {'decoded_out': decoded_out}]
else:
return loader, {'decoded_out': decoded_out}
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
from paddle.fluid.initializer import MSRA
from paddle.fluid.param_attr import ParamAttr
__all__ = ['MobileNetV3']
class MobileNetV3():
def __init__(self, params):
"""
the MobilenetV3 backbone network for detection module.
Args:
params(dict): the super parameters for build network
"""
self.scale = params['scale']
model_name = params['model_name']
self.inplanes = 16
if model_name == "large":
self.cfg = [
# k, exp, c, se, nl, s,
[3, 16, 16, False, 'relu', 1],
[3, 64, 24, False, 'relu', 2],
[3, 72, 24, False, 'relu', 1],
[5, 72, 40, True, 'relu', 2],
[5, 120, 40, True, 'relu', 1],
[5, 120, 40, True, 'relu', 1],
[3, 240, 80, False, 'hard_swish', 2],
[3, 200, 80, False, 'hard_swish', 1],
[3, 184, 80, False, 'hard_swish', 1],
[3, 184, 80, False, 'hard_swish', 1],
[3, 480, 112, True, 'hard_swish', 1],
[3, 672, 112, True, 'hard_swish', 1],
[5, 672, 160, True, 'hard_swish', 2],
[5, 960, 160, True, 'hard_swish', 1],
[5, 960, 160, True, 'hard_swish', 1],
]
self.cls_ch_squeeze = 960
self.cls_ch_expand = 1280
elif model_name == "small":
self.cfg = [
# k, exp, c, se, nl, s,
[3, 16, 16, True, 'relu', 2],
[3, 72, 24, False, 'relu', 2],
[3, 88, 24, False, 'relu', 1],
[5, 96, 40, True, 'hard_swish', 2],
[5, 240, 40, True, 'hard_swish', 1],
[5, 240, 40, True, 'hard_swish', 1],
[5, 120, 48, True, 'hard_swish', 1],
[5, 144, 48, True, 'hard_swish', 1],
[5, 288, 96, True, 'hard_swish', 2],
[5, 576, 96, True, 'hard_swish', 1],
[5, 576, 96, True, 'hard_swish', 1],
]
self.cls_ch_squeeze = 576
self.cls_ch_expand = 1280
else:
raise NotImplementedError("mode[" + model_name +
"_model] is not implemented!")
supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]
assert self.scale in supported_scale, \
"supported scale are {} but input scale is {}".format(supported_scale, self.scale)
def __call__(self, input):
scale = self.scale
inplanes = self.inplanes
cfg = self.cfg
cls_ch_squeeze = self.cls_ch_squeeze
cls_ch_expand = self.cls_ch_expand
#conv1
conv = self.conv_bn_layer(
input,
filter_size=3,
num_filters=self.make_divisible(inplanes * scale),
stride=2,
padding=1,
num_groups=1,
if_act=True,
act='hard_swish',
name='conv1')
i = 0
inplanes = self.make_divisible(inplanes * scale)
outs = []
for layer_cfg in cfg:
if layer_cfg[5] == 2 and i > 2:
outs.append(conv)
conv = self.residual_unit(
input=conv,
num_in_filter=inplanes,
num_mid_filter=self.make_divisible(scale * layer_cfg[1]),
num_out_filter=self.make_divisible(scale * layer_cfg[2]),
act=layer_cfg[4],
stride=layer_cfg[5],
filter_size=layer_cfg[0],
use_se=layer_cfg[3],
name='conv' + str(i + 2))
inplanes = self.make_divisible(scale * layer_cfg[2])
i += 1
conv = self.conv_bn_layer(
input=conv,
filter_size=1,
num_filters=self.make_divisible(scale * cls_ch_squeeze),
stride=1,
padding=0,
num_groups=1,
if_act=True,
act='hard_swish',
name='conv_last')
outs.append(conv)
return outs
def conv_bn_layer(self,
input,
filter_size,
num_filters,
stride,
padding,
num_groups=1,
if_act=True,
act=None,
name=None,
use_cudnn=True,
res_last_bn_init=False):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=padding,
groups=num_groups,
act=None,
use_cudnn=use_cudnn,
param_attr=ParamAttr(name=name + '_weights'),
bias_attr=False)
bn_name = name + '_bn'
bn = fluid.layers.batch_norm(
input=conv,
param_attr=ParamAttr(
name=bn_name + "_scale",
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=0.0)),
bias_attr=ParamAttr(
name=bn_name + "_offset",
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=0.0)),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
if if_act:
if act == 'relu':
bn = fluid.layers.relu(bn)
elif act == 'hard_swish':
bn = fluid.layers.hard_swish(bn)
return bn
def make_divisible(self, v, divisor=8, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
if new_v < 0.9 * v:
new_v += divisor
return new_v
def se_block(self, input, num_out_filter, ratio=4, name=None):
num_mid_filter = num_out_filter // ratio
pool = fluid.layers.pool2d(
input=input, pool_type='avg', global_pooling=True, use_cudnn=False)
conv1 = fluid.layers.conv2d(
input=pool,
filter_size=1,
num_filters=num_mid_filter,
act='relu',
param_attr=ParamAttr(name=name + '_1_weights'),
bias_attr=ParamAttr(name=name + '_1_offset'))
conv2 = fluid.layers.conv2d(
input=conv1,
filter_size=1,
num_filters=num_out_filter,
act='hard_sigmoid',
param_attr=ParamAttr(name=name + '_2_weights'),
bias_attr=ParamAttr(name=name + '_2_offset'))
scale = fluid.layers.elementwise_mul(x=input, y=conv2, axis=0)
return scale
def residual_unit(self,
input,
num_in_filter,
num_mid_filter,
num_out_filter,
stride,
filter_size,
act=None,
use_se=False,
name=None):
conv0 = self.conv_bn_layer(
input=input,
filter_size=1,
num_filters=num_mid_filter,
stride=1,
padding=0,
if_act=True,
act=act,
name=name + '_expand')
conv1 = self.conv_bn_layer(
input=conv0,
filter_size=filter_size,
num_filters=num_mid_filter,
stride=stride,
padding=int((filter_size - 1) // 2),
if_act=True,
act=act,
num_groups=num_mid_filter,
use_cudnn=False,
name=name + '_depthwise')
if use_se:
conv1 = self.se_block(
input=conv1, num_out_filter=num_mid_filter, name=name + '_se')
conv2 = self.conv_bn_layer(
input=conv1,
filter_size=1,
num_filters=num_out_filter,
stride=1,
padding=0,
if_act=False,
name=name + '_linear',
res_last_bn_init=True)
if num_in_filter != num_out_filter or stride != 1:
return conv2
else:
return fluid.layers.elementwise_add(x=input, y=conv2, act=None)
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
__all__ = ["ResNet"]
class ResNet(object):
def __init__(self, params):
"""
the Resnet backbone network for detection module.
Args:
params(dict): the super parameters for network build
"""
self.layers = params['layers']
supported_layers = [18, 34, 50, 101, 152]
assert self.layers in supported_layers, \
"supported layers are {} but input layer is {}".format(supported_layers, self.layers)
self.is_3x3 = True
def __call__(self, input):
layers = self.layers
is_3x3 = self.is_3x3
if layers == 18:
depth = [2, 2, 2, 2]
elif layers == 34 or layers == 50:
depth = [3, 4, 6, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
elif layers == 200:
depth = [3, 12, 48, 3]
num_filters = [64, 128, 256, 512]
outs = []
if is_3x3 == False:
conv = self.conv_bn_layer(
input=input,
num_filters=64,
filter_size=7,
stride=2,
act='relu')
else:
conv = self.conv_bn_layer(
input=input,
num_filters=32,
filter_size=3,
stride=2,
act='relu',
name='conv1_1')
conv = self.conv_bn_layer(
input=conv,
num_filters=32,
filter_size=3,
stride=1,
act='relu',
name='conv1_2')
conv = self.conv_bn_layer(
input=conv,
num_filters=64,
filter_size=3,
stride=1,
act='relu',
name='conv1_3')
conv = fluid.layers.pool2d(
input=conv,
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max')
if layers >= 50:
for block in range(len(depth)):
for i in range(depth[block]):
if layers in [101, 152, 200] and block == 2:
if i == 0:
conv_name = "res" + str(block + 2) + "a"
else:
conv_name = "res" + str(block + 2) + "b" + str(i)
else:
conv_name = "res" + str(block + 2) + chr(97 + i)
conv = self.bottleneck_block(
input=conv,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
if_first=block == i == 0,
name=conv_name)
outs.append(conv)
else:
for block in range(len(depth)):
for i in range(depth[block]):
conv_name = "res" + str(block + 2) + chr(97 + i)
conv = self.basic_block(
input=conv,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
if_first=block == i == 0,
name=conv_name)
outs.append(conv)
return outs
def conv_bn_layer(self,
input,
num_filters,
filter_size,
stride=1,
groups=1,
act=None,
name=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
return fluid.layers.batch_norm(
input=conv,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
def conv_bn_layer_new(self,
input,
num_filters,
filter_size,
stride=1,
groups=1,
act=None,
name=None):
pool = fluid.layers.pool2d(
input=input,
pool_size=2,
pool_stride=2,
pool_padding=0,
pool_type='avg',
ceil_mode=True)
conv = fluid.layers.conv2d(
input=pool,
num_filters=num_filters,
filter_size=filter_size,
stride=1,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
return fluid.layers.batch_norm(
input=conv,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
def shortcut(self, input, ch_out, stride, name, if_first=False):
ch_in = input.shape[1]
if ch_in != ch_out or stride != 1:
if if_first:
return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
else:
return self.conv_bn_layer_new(
input, ch_out, 1, stride, name=name)
elif if_first:
return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
else:
return input
def bottleneck_block(self, input, num_filters, stride, name, if_first):
conv0 = self.conv_bn_layer(
input=input,
num_filters=num_filters,
filter_size=1,
act='relu',
name=name + "_branch2a")
conv1 = self.conv_bn_layer(
input=conv0,
num_filters=num_filters,
filter_size=3,
stride=stride,
act='relu',
name=name + "_branch2b")
conv2 = self.conv_bn_layer(
input=conv1,
num_filters=num_filters * 4,
filter_size=1,
act=None,
name=name + "_branch2c")
short = self.shortcut(
input,
num_filters * 4,
stride,
if_first=if_first,
name=name + "_branch1")
return fluid.layers.elementwise_add(x=short, y=conv2, act='relu')
def basic_block(self, input, num_filters, stride, name, if_first):
conv0 = self.conv_bn_layer(
input=input,
num_filters=num_filters,
filter_size=3,
act='relu',
stride=stride,
name=name + "_branch2a")
conv1 = self.conv_bn_layer(
input=conv0,
num_filters=num_filters,
filter_size=3,
act=None,
name=name + "_branch2b")
short = self.shortcut(
input,
num_filters,
stride,
if_first=if_first,
name=name + "_branch1")
return fluid.layers.elementwise_add(x=short, y=conv1, act='relu')
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
from paddle.fluid.initializer import MSRA
from paddle.fluid.param_attr import ParamAttr
__all__ = [
'MobileNetV3', 'MobileNetV3_small_x0_35', 'MobileNetV3_small_x0_5',
'MobileNetV3_small_x0_75', 'MobileNetV3_small_x1_0',
'MobileNetV3_small_x1_25', 'MobileNetV3_large_x0_35',
'MobileNetV3_large_x0_5', 'MobileNetV3_large_x0_75',
'MobileNetV3_large_x1_0', 'MobileNetV3_large_x1_25'
]
class MobileNetV3():
def __init__(self, params):
self.scale = params['scale']
model_name = params['model_name']
self.inplanes = 16
if model_name == "large":
self.cfg = [
# k, exp, c, se, nl, s,
[3, 16, 16, False, 'relu', 1],
[3, 64, 24, False, 'relu', (2, 1)],
[3, 72, 24, False, 'relu', 1],
[5, 72, 40, True, 'relu', (2, 1)],
[5, 120, 40, True, 'relu', 1],
[5, 120, 40, True, 'relu', 1],
[3, 240, 80, False, 'hard_swish', 1],
[3, 200, 80, False, 'hard_swish', 1],
[3, 184, 80, False, 'hard_swish', 1],
[3, 184, 80, False, 'hard_swish', 1],
[3, 480, 112, True, 'hard_swish', 1],
[3, 672, 112, True, 'hard_swish', 1],
[5, 672, 160, True, 'hard_swish', (2, 1)],
[5, 960, 160, True, 'hard_swish', 1],
[5, 960, 160, True, 'hard_swish', 1],
]
self.cls_ch_squeeze = 960
self.cls_ch_expand = 1280
elif model_name == "small":
self.cfg = [
# k, exp, c, se, nl, s,
[3, 16, 16, True, 'relu', (2, 1)],
[3, 72, 24, False, 'relu', (2, 1)],
[3, 88, 24, False, 'relu', 1],
[5, 96, 40, True, 'hard_swish', (2, 1)],
[5, 240, 40, True, 'hard_swish', 1],
[5, 240, 40, True, 'hard_swish', 1],
[5, 120, 48, True, 'hard_swish', 1],
[5, 144, 48, True, 'hard_swish', 1],
[5, 288, 96, True, 'hard_swish', (2, 1)],
[5, 576, 96, True, 'hard_swish', 1],
[5, 576, 96, True, 'hard_swish', 1],
]
self.cls_ch_squeeze = 576
self.cls_ch_expand = 1280
else:
raise NotImplementedError("mode[" + model_name +
"_model] is not implemented!")
supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]
assert self.scale in supported_scale, \
"supported scale are {} but input scale is {}".format(supported_scale, scale)
def __call__(self, input):
scale = self.scale
inplanes = self.inplanes
cfg = self.cfg
cls_ch_squeeze = self.cls_ch_squeeze
cls_ch_expand = self.cls_ch_expand
#conv1
conv = self.conv_bn_layer(
input,
filter_size=3,
num_filters=self.make_divisible(inplanes * scale),
stride=2,
padding=1,
num_groups=1,
if_act=True,
act='hard_swish',
name='conv1')
i = 0
inplanes = self.make_divisible(inplanes * scale)
for layer_cfg in cfg:
conv = self.residual_unit(
input=conv,
num_in_filter=inplanes,
num_mid_filter=self.make_divisible(scale * layer_cfg[1]),
num_out_filter=self.make_divisible(scale * layer_cfg[2]),
act=layer_cfg[4],
stride=layer_cfg[5],
filter_size=layer_cfg[0],
use_se=layer_cfg[3],
name='conv' + str(i + 2))
inplanes = self.make_divisible(scale * layer_cfg[2])
i += 1
conv = self.conv_bn_layer(
input=conv,
filter_size=1,
num_filters=self.make_divisible(scale * cls_ch_squeeze),
stride=1,
padding=0,
num_groups=1,
if_act=True,
act='hard_swish',
name='conv_last')
conv = fluid.layers.pool2d(
input=conv,
pool_size=2,
pool_stride=2,
pool_padding=0,
pool_type='max')
return conv
def conv_bn_layer(self,
input,
filter_size,
num_filters,
stride,
padding,
num_groups=1,
if_act=True,
act=None,
name=None,
use_cudnn=True,
res_last_bn_init=False):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=padding,
groups=num_groups,
act=None,
use_cudnn=use_cudnn,
param_attr=ParamAttr(name=name + '_weights'),
bias_attr=False)
bn_name = name + '_bn'
bn = fluid.layers.batch_norm(
input=conv,
param_attr=ParamAttr(
name=bn_name + "_scale",
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=0.0)),
bias_attr=ParamAttr(
name=bn_name + "_offset",
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=0.0)),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
if if_act:
if act == 'relu':
bn = fluid.layers.relu(bn)
elif act == 'hard_swish':
bn = fluid.layers.hard_swish(bn)
return bn
def make_divisible(self, v, divisor=8, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
if new_v < 0.9 * v:
new_v += divisor
return new_v
def se_block(self, input, num_out_filter, ratio=4, name=None):
num_mid_filter = num_out_filter // ratio
pool = fluid.layers.pool2d(
input=input, pool_type='avg', global_pooling=True, use_cudnn=False)
conv1 = fluid.layers.conv2d(
input=pool,
filter_size=1,
num_filters=num_mid_filter,
act='relu',
param_attr=ParamAttr(name=name + '_1_weights'),
bias_attr=ParamAttr(name=name + '_1_offset'))
conv2 = fluid.layers.conv2d(
input=conv1,
filter_size=1,
num_filters=num_out_filter,
act='hard_sigmoid',
param_attr=ParamAttr(name=name + '_2_weights'),
bias_attr=ParamAttr(name=name + '_2_offset'))
scale = fluid.layers.elementwise_mul(x=input, y=conv2, axis=0)
return scale
def residual_unit(self,
input,
num_in_filter,
num_mid_filter,
num_out_filter,
stride,
filter_size,
act=None,
use_se=False,
name=None):
conv0 = self.conv_bn_layer(
input=input,
filter_size=1,
num_filters=num_mid_filter,
stride=1,
padding=0,
if_act=True,
act=act,
name=name + '_expand')
conv1 = self.conv_bn_layer(
input=conv0,
filter_size=filter_size,
num_filters=num_mid_filter,
stride=stride,
padding=int((filter_size - 1) // 2),
if_act=True,
act=act,
num_groups=num_mid_filter,
use_cudnn=False,
name=name + '_depthwise')
if use_se:
conv1 = self.se_block(
input=conv1, num_out_filter=num_mid_filter, name=name + '_se')
conv2 = self.conv_bn_layer(
input=conv1,
filter_size=1,
num_filters=num_out_filter,
stride=1,
padding=0,
if_act=False,
name=name + '_linear',
res_last_bn_init=True)
if num_in_filter != num_out_filter or stride != 1:
return conv2
else:
return fluid.layers.elementwise_add(x=input, y=conv2, act=None)
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
__all__ = [
"ResNet", "ResNet18_vd", "ResNet34_vd", "ResNet50_vd", "ResNet101_vd",
"ResNet152_vd", "ResNet200_vd"
]
class ResNet():
def __init__(self, params):
self.layers = params['layers']
self.is_3x3 = True
supported_layers = [18, 34, 50, 101, 152, 200]
assert self.layers in supported_layers, \
"supported layers are {} but input layer is {}".format(supported_layers, self.layers)
def __call__(self, input):
is_3x3 = self.is_3x3
layers = self.layers
if layers == 18:
depth = [2, 2, 2, 2]
elif layers == 34 or layers == 50:
depth = [3, 4, 6, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
elif layers == 200:
depth = [3, 12, 48, 3]
num_filters = [64, 128, 256, 512]
if is_3x3 == False:
conv = self.conv_bn_layer(
input=input,
num_filters=64,
filter_size=7,
stride=1,
act='relu')
else:
conv = self.conv_bn_layer(
input=input,
num_filters=32,
filter_size=3,
stride=1,
act='relu',
name='conv1_1')
conv = self.conv_bn_layer(
input=conv,
num_filters=32,
filter_size=3,
stride=1,
act='relu',
name='conv1_2')
conv = self.conv_bn_layer(
input=conv,
num_filters=64,
filter_size=3,
stride=1,
act='relu',
name='conv1_3')
conv = fluid.layers.pool2d(
input=conv,
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max')
if layers >= 50:
for block in range(len(depth)):
for i in range(depth[block]):
if layers in [101, 152, 200] and block == 2:
if i == 0:
conv_name = "res" + str(block + 2) + "a"
else:
conv_name = "res" + str(block + 2) + "b" + str(i)
else:
conv_name = "res" + str(block + 2) + chr(97 + i)
if i == 0 and block != 0:
stride = (2, 1)
else:
stride = (1, 1)
conv = self.bottleneck_block(
input=conv,
num_filters=num_filters[block],
stride=stride,
if_first=block == i == 0,
name=conv_name)
else:
for block in range(len(depth)):
for i in range(depth[block]):
conv_name = "res" + str(block + 2) + chr(97 + i)
if i == 0 and block != 0:
stride = (2, 1)
else:
stride = (1, 1)
conv = self.basic_block(
input=conv,
num_filters=num_filters[block],
stride=stride,
if_first=block == i == 0,
name=conv_name)
conv = fluid.layers.pool2d(
input=conv,
pool_size=2,
pool_stride=2,
pool_padding=0,
pool_type='max')
return conv
def conv_bn_layer(self,
input,
num_filters,
filter_size,
stride=1,
groups=1,
act=None,
name=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
return fluid.layers.batch_norm(
input=conv,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
def conv_bn_layer_new(self,
input,
num_filters,
filter_size,
stride=1,
groups=1,
act=None,
name=None):
pool = fluid.layers.pool2d(
input=input,
pool_size=stride,
pool_stride=stride,
pool_padding=0,
pool_type='avg',
ceil_mode=True)
conv = fluid.layers.conv2d(
input=pool,
num_filters=num_filters,
filter_size=filter_size,
stride=1,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
return fluid.layers.batch_norm(
input=conv,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
def shortcut(self, input, ch_out, stride, name, if_first=False):
ch_in = input.shape[1]
if ch_in != ch_out or stride[0] != 1:
if if_first:
return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
else:
return self.conv_bn_layer_new(
input, ch_out, 1, stride, name=name)
elif if_first:
return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
else:
return input
def bottleneck_block(self, input, num_filters, stride, name, if_first):
conv0 = self.conv_bn_layer(
input=input,
num_filters=num_filters,
filter_size=1,
act='relu',
name=name + "_branch2a")
conv1 = self.conv_bn_layer(
input=conv0,
num_filters=num_filters,
filter_size=3,
stride=stride,
act='relu',
name=name + "_branch2b")
conv2 = self.conv_bn_layer(
input=conv1,
num_filters=num_filters * 4,
filter_size=1,
act=None,
name=name + "_branch2c")
short = self.shortcut(
input,
num_filters * 4,
stride,
if_first=if_first,
name=name + "_branch1")
return fluid.layers.elementwise_add(x=short, y=conv2, act='relu')
def basic_block(self, input, num_filters, stride, name, if_first):
conv0 = self.conv_bn_layer(
input=input,
num_filters=num_filters,
filter_size=3,
act='relu',
stride=stride,
name=name + "_branch2a")
conv1 = self.conv_bn_layer(
input=conv0,
num_filters=num_filters,
filter_size=3,
act=None,
name=name + "_branch2b")
short = self.shortcut(
input,
num_filters,
stride,
if_first=if_first,
name=name + "_branch1")
return fluid.layers.elementwise_add(x=short, y=conv1, act='relu')
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
import math
def get_para_bias_attr(l2_decay, k, name):
regularizer = fluid.regularizer.L2Decay(l2_decay)
stdv = 1.0 / math.sqrt(k * 1.0)
initializer = fluid.initializer.Uniform(-stdv, stdv)
para_attr = fluid.ParamAttr(
regularizer=regularizer, initializer=initializer, name=name + "_w_attr")
bias_attr = fluid.ParamAttr(
regularizer=regularizer, initializer=initializer, name=name + "_b_attr")
return [para_attr, bias_attr]
def conv_bn_layer(input,
num_filters,
filter_size,
stride=1,
groups=1,
act=None,
name=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=False,
name=name + '.conv2d')
bn_name = "bn_" + name
return fluid.layers.batch_norm(
input=conv,
act=act,
name=bn_name + '.output',
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
def deconv_bn_layer(input,
num_filters,
filter_size=4,
stride=2,
act='relu',
name=None):
deconv = fluid.layers.conv2d_transpose(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=1,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=False,
name=name + '.deconv2d')
bn_name = "bn_" + name
return fluid.layers.batch_norm(
input=deconv,
act=act,
name=bn_name + '.output',
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
def create_tmp_var(program, name, dtype, shape, lod_level=0):
return program.current_block().create_var(
name=name, dtype=dtype, shape=shape, lod_level=lod_level)
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle.fluid as fluid
class DBHead(object):
"""
Differentiable Binarization (DB) for text detection:
see https://arxiv.org/abs/1911.08947
args:
params(dict): super parameters for build DB network
"""
def __init__(self, params):
self.k = params['k']
self.inner_channels = params['inner_channels']
self.C, self.H, self.W = params['image_shape']
print(self.C, self.H, self.W)
def binarize(self, x):
conv1 = fluid.layers.conv2d(
input=x,
num_filters=self.inner_channels // 4,
filter_size=3,
padding=1,
param_attr=fluid.initializer.MSRAInitializer(uniform=False),
bias_attr=False)
conv_bn1 = fluid.layers.batch_norm(
input=conv1,
param_attr=fluid.initializer.ConstantInitializer(value=1.0),
bias_attr=fluid.initializer.ConstantInitializer(value=1e-4),
act="relu")
conv2 = fluid.layers.conv2d_transpose(
input=conv_bn1,
num_filters=self.inner_channels // 4,
filter_size=2,
stride=2,
param_attr=fluid.initializer.MSRAInitializer(uniform=False),
bias_attr=self._get_bias_attr(0.0004, conv_bn1.shape[1], "conv2"),
act=None)
conv_bn2 = fluid.layers.batch_norm(
input=conv2,
param_attr=fluid.initializer.ConstantInitializer(value=1.0),
bias_attr=fluid.initializer.ConstantInitializer(value=1e-4),
act="relu")
conv3 = fluid.layers.conv2d_transpose(
input=conv_bn2,
num_filters=1,
filter_size=2,
stride=2,
param_attr=fluid.initializer.MSRAInitializer(uniform=False),
bias_attr=self._get_bias_attr(0.0004, conv_bn2.shape[1], "conv3"),
act=None)
out = fluid.layers.sigmoid(conv3)
return out
def thresh(self, x):
conv1 = fluid.layers.conv2d(
input=x,
num_filters=self.inner_channels // 4,
filter_size=3,
padding=1,
param_attr=fluid.initializer.MSRAInitializer(uniform=False),
bias_attr=False)
conv_bn1 = fluid.layers.batch_norm(
input=conv1,
param_attr=fluid.initializer.ConstantInitializer(value=1.0),
bias_attr=fluid.initializer.ConstantInitializer(value=1e-4),
act="relu")
conv2 = fluid.layers.conv2d_transpose(
input=conv_bn1,
num_filters=self.inner_channels // 4,
filter_size=2,
stride=2,
param_attr=fluid.initializer.MSRAInitializer(uniform=False),
bias_attr=self._get_bias_attr(0.0004, conv_bn1.shape[1], "conv2"),
act=None)
conv_bn2 = fluid.layers.batch_norm(
input=conv2,
param_attr=fluid.initializer.ConstantInitializer(value=1.0),
bias_attr=fluid.initializer.ConstantInitializer(value=1e-4),
act="relu")
conv3 = fluid.layers.conv2d_transpose(
input=conv_bn2,
num_filters=1,
filter_size=2,
stride=2,
param_attr=fluid.initializer.MSRAInitializer(uniform=False),
bias_attr=self._get_bias_attr(0.0004, conv_bn2.shape[1], "conv3"),
act=None)
out = fluid.layers.sigmoid(conv3)
return out
def _get_bias_attr(self, l2_decay, k, name, gradient_clip=None):
regularizer = fluid.regularizer.L2Decay(l2_decay)
stdv = 1.0 / math.sqrt(k * 1.0)
initializer = fluid.initializer.Uniform(-stdv, stdv)
bias_attr = fluid.ParamAttr(
regularizer=regularizer,
gradient_clip=gradient_clip,
initializer=initializer,
name=name + "_b_attr")
return bias_attr
def step_function(self, x, y):
return fluid.layers.reciprocal(1 + fluid.layers.exp(-self.k * (x - y)))
def __call__(self, conv_features, mode="train"):
c2, c3, c4, c5 = conv_features
param_attr = fluid.initializer.MSRAInitializer(uniform=False)
in5 = fluid.layers.conv2d(
input=c5,
num_filters=self.inner_channels,
filter_size=1,
param_attr=param_attr,
bias_attr=False)
in4 = fluid.layers.conv2d(
input=c4,
num_filters=self.inner_channels,
filter_size=1,
param_attr=param_attr,
bias_attr=False)
in3 = fluid.layers.conv2d(
input=c3,
num_filters=self.inner_channels,
filter_size=1,
param_attr=param_attr,
bias_attr=False)
in2 = fluid.layers.conv2d(
input=c2,
num_filters=self.inner_channels,
filter_size=1,
param_attr=param_attr,
bias_attr=False)
out4 = fluid.layers.elementwise_add(
x=fluid.layers.resize_nearest(
input=in5, scale=2), y=in4) # 1/16
out3 = fluid.layers.elementwise_add(
x=fluid.layers.resize_nearest(
input=out4, scale=2), y=in3) # 1/8
out2 = fluid.layers.elementwise_add(
x=fluid.layers.resize_nearest(
input=out3, scale=2), y=in2) # 1/4
p5 = fluid.layers.conv2d(
input=in5,
num_filters=self.inner_channels // 4,
filter_size=3,
padding=1,
param_attr=param_attr,
bias_attr=False)
p5 = fluid.layers.resize_nearest(input=p5, scale=8)
p4 = fluid.layers.conv2d(
input=out4,
num_filters=self.inner_channels // 4,
filter_size=3,
padding=1,
param_attr=param_attr,
bias_attr=False)
p4 = fluid.layers.resize_nearest(input=p4, scale=4)
p3 = fluid.layers.conv2d(
input=out3,
num_filters=self.inner_channels // 4,
filter_size=3,
padding=1,
param_attr=param_attr,
bias_attr=False)
p3 = fluid.layers.resize_nearest(input=p3, scale=2)
p2 = fluid.layers.conv2d(
input=out2,
num_filters=self.inner_channels // 4,
filter_size=3,
padding=1,
param_attr=param_attr,
bias_attr=False)
fuse = fluid.layers.concat(input=[p5, p4, p3, p2], axis=1)
shrink_maps = self.binarize(fuse)
if mode != "train":
return shrink_maps
threshold_maps = self.thresh(fuse)
binary_maps = self.step_function(shrink_maps, threshold_maps)
y = fluid.layers.concat(
input=[shrink_maps, threshold_maps, binary_maps], axis=1)
predicts = {}
predicts['maps'] = y
return predicts
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
from ..common_functions import conv_bn_layer, deconv_bn_layer
class EASTHead(object):
"""
EAST: An Efficient and Accurate Scene Text Detector
see arxiv: https://arxiv.org/abs/1704.03155
args:
params(dict): the super parameters for network build
"""
def __init__(self, params):
self.model_name = params['model_name']
def unet_fusion(self, inputs):
f = inputs[::-1]
if self.model_name == "large":
num_outputs = [128, 128, 128, 128]
else:
num_outputs = [64, 64, 64, 64]
g = [None, None, None, None]
h = [None, None, None, None]
for i in range(4):
if i == 0:
h[i] = f[i]
else:
h[i] = fluid.layers.concat([g[i - 1], f[i]], axis=1)
h[i] = conv_bn_layer(
input=h[i],
num_filters=num_outputs[i],
filter_size=3,
stride=1,
act='relu',
name="unet_h_%d" % (i))
if i <= 2:
#can be replaced with unpool
g[i] = deconv_bn_layer(
input=h[i],
num_filters=num_outputs[i],
name="unet_g_%d" % (i))
else:
g[i] = conv_bn_layer(
input=h[i],
num_filters=num_outputs[i],
filter_size=3,
stride=1,
act='relu',
name="unet_g_%d" % (i))
return g[3]
def detector_header(self, f_common):
if self.model_name == "large":
num_outputs = [128, 64, 1, 8]
else:
num_outputs = [64, 32, 1, 8]
f_det = conv_bn_layer(
input=f_common,
num_filters=num_outputs[0],
filter_size=3,
stride=1,
act='relu',
name="det_head1")
f_det = conv_bn_layer(
input=f_det,
num_filters=num_outputs[1],
filter_size=3,
stride=1,
act='relu',
name="det_head2")
#f_score
f_score = conv_bn_layer(
input=f_det,
num_filters=num_outputs[2],
filter_size=1,
stride=1,
act=None,
name="f_score")
f_score = fluid.layers.sigmoid(f_score)
#f_geo
f_geo = conv_bn_layer(
input=f_det,
num_filters=num_outputs[3],
filter_size=1,
stride=1,
act=None,
name="f_geo")
f_geo = (fluid.layers.sigmoid(f_geo) - 0.5) * 2 * 800
return f_score, f_geo
def __call__(self, inputs):
f_common = self.unet_fusion(inputs)
f_score, f_geo = self.detector_header(f_common)
predicts = {}
predicts['f_score'] = f_score
predicts['f_geo'] = f_geo
return predicts
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from .rec_seq_encoder import SequenceEncoder
import numpy as np
class AttentionPredict(object):
def __init__(self, params):
super(AttentionPredict, self).__init__()
self.char_num = params['char_num']
self.encoder = SequenceEncoder(params)
self.decoder_size = params['Attention']['decoder_size']
self.word_vector_dim = params['Attention']['word_vector_dim']
self.encoder_type = params['encoder_type']
self.max_length = params['max_text_length']
def simple_attention(self, encoder_vec, encoder_proj, decoder_state,
decoder_size):
decoder_state_proj = layers.fc(input=decoder_state,
size=decoder_size,
bias_attr=False,
name="decoder_state_proj_fc")
decoder_state_expand = layers.sequence_expand(
x=decoder_state_proj, y=encoder_proj)
concated = layers.elementwise_add(encoder_proj, decoder_state_expand)
concated = layers.tanh(x=concated)
attention_weights = layers.fc(input=concated,
size=1,
act=None,
bias_attr=False,
name="attention_weights_fc")
attention_weights = layers.sequence_softmax(input=attention_weights)
weigths_reshape = layers.reshape(x=attention_weights, shape=[-1])
scaled = layers.elementwise_mul(
x=encoder_vec, y=weigths_reshape, axis=0)
context = layers.sequence_pool(input=scaled, pool_type='sum')
return context
def gru_decoder_with_attention(self, target_embedding, encoder_vec,
encoder_proj, decoder_boot, decoder_size,
char_num):
rnn = layers.DynamicRNN()
with rnn.block():
current_word = rnn.step_input(target_embedding)
encoder_vec = rnn.static_input(encoder_vec)
encoder_proj = rnn.static_input(encoder_proj)
hidden_mem = rnn.memory(init=decoder_boot, need_reorder=True)
context = self.simple_attention(encoder_vec, encoder_proj,
hidden_mem, decoder_size)
fc_1 = layers.fc(input=context,
size=decoder_size * 3,
bias_attr=False,
name="rnn_fc1")
fc_2 = layers.fc(input=current_word,
size=decoder_size * 3,
bias_attr=False,
name="rnn_fc2")
decoder_inputs = fc_1 + fc_2
h, _, _ = layers.gru_unit(
input=decoder_inputs, hidden=hidden_mem, size=decoder_size * 3)
rnn.update_memory(hidden_mem, h)
out = layers.fc(input=h,
size=char_num,
bias_attr=True,
act='softmax',
name="rnn_out_fc")
rnn.output(out)
return rnn()
def gru_attention_infer(self, decoder_boot, max_length, char_num,
word_vector_dim, encoded_vector, encoded_proj,
decoder_size):
init_state = decoder_boot
beam_size = 1
array_len = layers.fill_constant(
shape=[1], dtype='int64', value=max_length)
counter = layers.zeros(shape=[1], dtype='int64', force_cpu=True)
# fill the first element with init_state
state_array = layers.create_array('float32')
layers.array_write(init_state, array=state_array, i=counter)
# ids, scores as memory
ids_array = layers.create_array('int64')
scores_array = layers.create_array('float32')
rois_shape = layers.shape(init_state)
batch_size = layers.slice(
rois_shape, axes=[0], starts=[0], ends=[1]) + 1
lod_level = layers.range(
start=0, end=batch_size, step=1, dtype=batch_size.dtype)
init_ids = layers.fill_constant_batch_size_like(
input=init_state, shape=[-1, 1], value=0, dtype='int64')
init_ids = layers.lod_reset(init_ids, lod_level)
init_ids = layers.lod_append(init_ids, lod_level)
init_scores = layers.fill_constant_batch_size_like(
input=init_state, shape=[-1, 1], value=1, dtype='float32')
init_scores = layers.lod_reset(init_scores, init_ids)
layers.array_write(init_ids, array=ids_array, i=counter)
layers.array_write(init_scores, array=scores_array, i=counter)
full_ids = fluid.layers.fill_constant_batch_size_like(
input=init_state, shape=[-1, 1], dtype='int64', value=1)
cond = layers.less_than(x=counter, y=array_len)
while_op = layers.While(cond=cond)
with while_op.block():
pre_ids = layers.array_read(array=ids_array, i=counter)
pre_state = layers.array_read(array=state_array, i=counter)
pre_score = layers.array_read(array=scores_array, i=counter)
pre_ids_emb = layers.embedding(
input=pre_ids,
size=[char_num, word_vector_dim],
dtype='float32')
context = self.simple_attention(encoded_vector, encoded_proj,
pre_state, decoder_size)
# expand the recursive_sequence_lengths of pre_state
# to be the same with pre_score
pre_state_expanded = layers.sequence_expand(pre_state, pre_score)
context_expanded = layers.sequence_expand(context, pre_score)
fc_1 = layers.fc(input=context_expanded,
size=decoder_size * 3,
bias_attr=False,
name="rnn_fc1")
fc_2 = layers.fc(input=pre_ids_emb,
size=decoder_size * 3,
bias_attr=False,
name="rnn_fc2")
decoder_inputs = fc_1 + fc_2
current_state, _, _ = layers.gru_unit(
input=decoder_inputs,
hidden=pre_state_expanded,
size=decoder_size * 3)
current_state_with_lod = layers.lod_reset(
x=current_state, y=pre_score)
# use score to do beam search
current_score = layers.fc(input=current_state_with_lod,
size=char_num,
bias_attr=True,
act='softmax',
name="rnn_out_fc")
topk_scores, topk_indices = layers.topk(current_score, k=beam_size)
new_ids = fluid.layers.concat([full_ids, topk_indices], axis=1)
fluid.layers.assign(new_ids, full_ids)
layers.increment(x=counter, value=1, in_place=True)
# update the memories
layers.array_write(current_state, array=state_array, i=counter)
layers.array_write(topk_indices, array=ids_array, i=counter)
layers.array_write(topk_scores, array=scores_array, i=counter)
# update the break condition:
# up to the max length or all candidates of
# source sentences have ended.
length_cond = layers.less_than(x=counter, y=array_len)
finish_cond = layers.logical_not(layers.is_empty(x=topk_indices))
layers.logical_and(x=length_cond, y=finish_cond, out=cond)
return full_ids
def __call__(self, inputs, labels=None, mode=None):
encoder_features = self.encoder(inputs)
char_num = self.char_num
word_vector_dim = self.word_vector_dim
decoder_size = self.decoder_size
if self.encoder_type == "reshape":
encoder_input = encoder_features
encoded_vector = encoder_features
else:
encoder_input = encoder_features[1]
encoded_vector = layers.concat(encoder_features, axis=1)
encoded_proj = layers.fc(input=encoded_vector,
size=decoder_size,
bias_attr=False,
name="encoded_proj_fc")
backward_first = layers.sequence_pool(
input=encoder_input, pool_type='first')
decoder_boot = layers.fc(input=backward_first,
size=decoder_size,
bias_attr=False,
act="relu",
name='decoder_boot')
if mode == "train":
label_in = labels['label_in']
label_out = labels['label_out']
label_in = layers.cast(x=label_in, dtype='int64')
trg_embedding = layers.embedding(
input=label_in,
size=[char_num, word_vector_dim],
dtype='float32')
predict = self.gru_decoder_with_attention(
trg_embedding, encoded_vector, encoded_proj, decoder_boot,
decoder_size, char_num)
_, decoded_out = layers.topk(input=predict, k=1)
decoded_out = layers.lod_reset(decoded_out, y=label_out)
predicts = {'predict': predict, 'decoded_out': decoded_out}
else:
ids = self.gru_attention_infer(
decoder_boot, self.max_length, char_num, word_vector_dim,
encoded_vector, encoded_proj, decoder_size)
predicts = {'decoded_out': ids}
return predicts
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from .rec_seq_encoder import SequenceEncoder
from ..common_functions import get_para_bias_attr
import numpy as np
class CTCPredict(object):
def __init__(self, params):
super(CTCPredict, self).__init__()
self.char_num = params['char_num']
self.encoder = SequenceEncoder(params)
self.encoder_type = params['encoder_type']
def __call__(self, inputs, labels=None, mode=None):
encoder_features = self.encoder(inputs)
if self.encoder_type != "reshape":
encoder_features = fluid.layers.concat(encoder_features, axis=1)
name = "ctc_fc"
para_attr, bias_attr = get_para_bias_attr(
l2_decay=0.0004, k=encoder_features.shape[1], name=name)
predict = fluid.layers.fc(input=encoder_features,
size=self.char_num + 1,
param_attr=para_attr,
bias_attr=bias_attr,
name=name)
decoded_out = fluid.layers.ctc_greedy_decoder(
input=predict, blank=self.char_num)
predicts = {'predict': predict, 'decoded_out': decoded_out}
return predicts
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle.fluid as fluid
import paddle.fluid.layers as layers
class EncoderWithReshape(object):
def __init__(self, params):
super(EncoderWithReshape, self).__init__()
def __call__(self, inputs):
sliced_feature = layers.im2sequence(
input=inputs,
stride=[1, 1],
filter_size=[inputs.shape[2], 1],
name="sliced_feature")
return sliced_feature
class EncoderWithRNN(object):
def __init__(self, params):
super(EncoderWithRNN, self).__init__()
self.rnn_hidden_size = params['SeqRNN']['hidden_size']
def __call__(self, inputs):
lstm_list = []
name_prefix = "lstm"
rnn_hidden_size = self.rnn_hidden_size
for no in range(1, 3):
if no == 1:
is_reverse = False
else:
is_reverse = True
name = "%s_st1_fc%d" % (name_prefix, no)
fc = layers.fc(input=inputs,
size=rnn_hidden_size * 4,
param_attr=fluid.ParamAttr(name=name + "_w"),
bias_attr=fluid.ParamAttr(name=name + "_b"),
name=name)
name = "%s_st1_out%d" % (name_prefix, no)
lstm, _ = layers.dynamic_lstm(
input=fc,
size=rnn_hidden_size * 4,
is_reverse=is_reverse,
param_attr=fluid.ParamAttr(name=name + "_w"),
bias_attr=fluid.ParamAttr(name=name + "_b"),
use_peepholes=False)
name = "%s_st2_fc%d" % (name_prefix, no)
fc = layers.fc(input=lstm,
size=rnn_hidden_size * 4,
param_attr=fluid.ParamAttr(name=name + "_w"),
bias_attr=fluid.ParamAttr(name=name + "_b"),
name=name)
name = "%s_st2_out%d" % (name_prefix, no)
lstm, _ = layers.dynamic_lstm(
input=fc,
size=rnn_hidden_size * 4,
is_reverse=is_reverse,
param_attr=fluid.ParamAttr(name=name + "_w"),
bias_attr=fluid.ParamAttr(name=name + "_b"),
use_peepholes=False)
lstm_list.append(lstm)
return lstm_list
class SequenceEncoder(object):
def __init__(self, params):
super(SequenceEncoder, self).__init__()
self.encoder_type = params['encoder_type']
self.encoder_reshape = EncoderWithReshape(params)
if self.encoder_type == "rnn":
self.encoder_rnn = EncoderWithRNN(params)
def __call__(self, inputs):
if self.encoder_type == "reshape":
encoder_features = self.encoder_reshape(inputs)
elif self.encoder_type == "rnn":
inputs = self.encoder_reshape(inputs)
encoder_features = self.encoder_rnn(inputs)
else:
assert False, "Unsupport encoder_type:%s"\
% self.encoder_type
return encoder_features
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle.fluid as fluid
def BalanceLoss(pred,
gt,
mask,
balance_loss=True,
main_loss_type="DiceLoss",
negative_ratio=3,
return_origin=False,
eps=1e-6):
"""
The BalanceLoss for Differentiable Binarization text detection
args:
pred (variable): predicted feature maps.
gt (variable): ground truth feature maps.
mask (variable): masked maps.
balance_loss (bool): whether balance loss or not, default is True
main_loss_type (str): can only be one of ['CrossEntropy','DiceLoss',
'Euclidean','BCELoss', 'MaskL1Loss'], default is 'DiceLoss'.
negative_ratio (int|float): float, default is 3.
return_origin (bool): whether return unbalanced loss or not, default is False.
eps (float): default is 1e-6.
return: (variable) balanced loss
"""
positive = gt * mask
negative = (1 - gt) * mask
positive_count = fluid.layers.reduce_sum(positive)
positive_count_int = fluid.layers.cast(positive_count, dtype=np.int32)
negative_count = min(
fluid.layers.reduce_sum(negative), positive_count * negative_ratio)
negative_count_int = fluid.layers.cast(negative_count, dtype=np.int32)
if main_loss_type == "CrossEntropy":
loss = fluid.layers.cross_entropy(input=pred, label=gt, soft_label=True)
loss = fluid.layers.reduce_mean(loss)
elif main_loss_type == "Euclidean":
loss = fluid.layers.square(pred - gt)
loss = fluid.layers.reduce_mean(loss)
elif main_loss_type == "DiceLoss":
loss = DiceLoss(pred, gt, mask)
elif main_loss_type == "BCELoss":
loss = fluid.layers.sigmoid_cross_entropy_with_logits(pred, label=gt)
elif main_loss_type == "MaskL1Loss":
loss = MaskL1Loss(pred, gt, mask)
else:
loss_type = [
'CrossEntropy', 'DiceLoss', 'Euclidean', 'BCELoss', 'MaskL1Loss'
]
raise Exception("main_loss_type in BalanceLoss() can only be one of {}".
format(loss_type))
if not balance_loss:
return loss
positive_loss = positive * loss
negative_loss = negative * loss
negative_loss = fluid.layers.reshape(negative_loss, shape=[-1])
negative_loss, _ = fluid.layers.topk(negative_loss, k=negative_count_int)
balance_loss = (fluid.layers.reduce_sum(positive_loss) +
fluid.layers.reduce_sum(negative_loss)) / (
positive_count + negative_count + eps)
if return_origin:
return balance_loss, loss
return balance_loss
def DiceLoss(pred, gt, mask, weights=None, eps=1e-6):
"""
DiceLoss function.
"""
assert pred.shape == gt.shape
assert pred.shape == mask.shape
if weights is not None:
assert weights.shape == mask.shape
mask = weights * mask
intersection = fluid.layers.reduce_sum(pred * gt * mask)
union = fluid.layers.reduce_sum(pred * mask) + fluid.layers.reduce_sum(
gt * mask) + eps
loss = 1 - 2.0 * intersection / union
assert loss <= 1
return loss
def MaskL1Loss(pred, gt, mask, eps=1e-6):
"""
Mask L1 Loss
"""
loss = fluid.layers.reduce_sum((fluid.layers.abs(pred - gt) * mask)) / (
fluid.layers.reduce_sum(mask) + eps)
loss = fluid.layers.reduce_mean(loss)
return loss
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss
class DBLoss(object):
"""
Differentiable Binarization (DB) Loss Function
args:
param (dict): the super paramter for DB Loss
"""
def __init__(self, params):
super(DBLoss, self).__init__()
self.balance_loss = params['balance_loss']
self.main_loss_type = params['main_loss_type']
self.alpha = params['alpha']
self.beta = params['beta']
self.ohem_ratio = params['ohem_ratio']
def __call__(self, predicts, labels):
label_shrink_map = labels['shrink_map']
label_shrink_mask = labels['shrink_mask']
label_threshold_map = labels['threshold_map']
label_threshold_mask = labels['threshold_mask']
pred = predicts['maps']
shrink_maps = pred[:, 0, :, :]
threshold_maps = pred[:, 1, :, :]
binary_maps = pred[:, 2, :, :]
loss_shrink_maps = BalanceLoss(
shrink_maps,
label_shrink_map,
label_shrink_mask,
balance_loss=self.balance_loss,
main_loss_type=self.main_loss_type,
negative_ratio=self.ohem_ratio)
loss_threshold_maps = MaskL1Loss(threshold_maps, label_threshold_map,
label_threshold_mask)
loss_binary_maps = DiceLoss(binary_maps, label_shrink_map,
label_shrink_mask)
loss_shrink_maps = self.alpha * loss_shrink_maps
loss_threshold_maps = self.beta * loss_threshold_maps
loss_all = loss_shrink_maps + loss_threshold_maps\
+ loss_binary_maps
losses = {'total_loss':loss_all,\
"loss_shrink_maps":loss_shrink_maps,\
"loss_threshold_maps":loss_threshold_maps,\
"loss_binary_maps":loss_binary_maps}
return losses
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
class EASTLoss(object):
"""
EAST Loss function
"""
def __init__(self, params=None):
super(EASTLoss, self).__init__()
def __call__(self, predicts, labels):
f_score = predicts['f_score']
f_geo = predicts['f_geo']
l_score = labels['score']
l_geo = labels['geo']
l_mask = labels['mask']
##dice_loss
intersection = fluid.layers.reduce_sum(f_score * l_score * l_mask)
union = fluid.layers.reduce_sum(f_score * l_mask)\
+ fluid.layers.reduce_sum(l_score * l_mask)
dice_loss = 1 - 2 * intersection / (union + 1e-5)
#smoooth_l1_loss
channels = 8
l_geo_split = fluid.layers.split(
l_geo, num_or_sections=channels + 1, dim=1)
f_geo_split = fluid.layers.split(f_geo, num_or_sections=channels, dim=1)
smooth_l1 = 0
for i in range(0, channels):
geo_diff = l_geo_split[i] - f_geo_split[i]
abs_geo_diff = fluid.layers.abs(geo_diff)
smooth_l1_sign = fluid.layers.less_than(abs_geo_diff, l_score)
smooth_l1_sign = fluid.layers.cast(smooth_l1_sign, dtype='float32')
in_loss = abs_geo_diff * abs_geo_diff * smooth_l1_sign + \
(abs_geo_diff - 0.5) * (1.0 - smooth_l1_sign)
out_loss = l_geo_split[-1] / channels * in_loss * l_score
smooth_l1 += out_loss
smooth_l1_loss = fluid.layers.reduce_mean(smooth_l1 * l_score)
dice_loss = dice_loss * 0.01
total_loss = dice_loss + smooth_l1_loss
losses = {'total_loss':total_loss, "dice_loss":dice_loss,\
"smooth_l1_loss":smooth_l1_loss}
return losses
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
import numpy as np
class AttentionLoss(object):
def __init__(self, params):
super(AttentionLoss, self).__init__()
self.char_num = params['char_num']
def __call__(self, predicts, labels):
predict = predicts['predict']
label_out = labels['label_out']
label_out = fluid.layers.cast(x=label_out, dtype='int64')
cost = fluid.layers.cross_entropy(input=predict, label=label_out)
sum_cost = fluid.layers.reduce_sum(cost)
return sum_cost
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
import paddle.fluid as fluid
class CTCLoss(object):
def __init__(self, params):
super(CTCLoss, self).__init__()
self.char_num = params['char_num']
def __call__(self, predicts, labels):
predict = predicts['predict']
label = labels['label']
cost = fluid.layers.warpctc(
input=predict, label=label, blank=self.char_num, norm_by_times=True)
sum_cost = fluid.layers.reduce_sum(cost)
return sum_cost
此差异已折叠。
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
def AdamDecay(params, parameter_list=None):
"""
define optimizer function
args:
params(dict): the super parameters
parameter_list (list): list of Variable names to update to minimize loss
return:
"""
base_lr = params['base_lr']
beta1 = params['beta1']
beta2 = params['beta2']
optimizer = fluid.optimizer.Adam(
learning_rate=base_lr,
beta1=beta1,
beta2=beta2,
parameter_list=parameter_list)
return optimizer
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.fluid as fluid
import numpy as np
import string
import cv2
from shapely.geometry import Polygon
import pyclipper
class DBPostProcess(object):
"""
The post process for Differentiable Binarization (DB).
"""
def __init__(self, params):
self.thresh = params['thresh']
self.box_thresh = params['box_thresh']
self.max_candidates = params['max_candidates']
self.min_size = 3
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
'''
_bitmap: single map with shape (1, H, W),
whose values are binarized as {0, 1}
'''
bitmap = _bitmap
height, width = bitmap.shape
# img, contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8),
cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
num_contours = min(len(contours), self.max_candidates)
boxes = np.zeros((num_contours, 4, 2), dtype=np.int16)
scores = np.zeros((num_contours, ), dtype=np.float32)
for index in range(num_contours):
contour = contours[index]
points, sside = self.get_mini_boxes(contour)
if sside < self.min_size:
continue
points = np.array(points)
score = self.box_score_fast(pred, points.reshape(-1, 2))
if self.box_thresh > score:
continue
box = self.unclip(points).reshape(-1, 1, 2)
box, sside = self.get_mini_boxes(box)
if sside < self.min_size + 2:
continue
box = np.array(box)
if not isinstance(dest_width, int):
dest_width = dest_width.item()
dest_height = dest_height.item()
box[:, 0] = np.clip(
np.round(box[:, 0] / width * dest_width), 0, dest_width)
box[:, 1] = np.clip(
np.round(box[:, 1] / height * dest_height), 0, dest_height)
boxes[index, :, :] = box.astype(np.int16)
scores[index] = score
return boxes, scores
def unclip(self, box, unclip_ratio=1.5):
poly = Polygon(box)
distance = poly.area * unclip_ratio / poly.length
offset = pyclipper.PyclipperOffset()
offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
expanded = np.array(offset.Execute(distance))
return expanded
def get_mini_boxes(self, contour):
bounding_box = cv2.minAreaRect(contour)
points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
index_1, index_2, index_3, index_4 = 0, 1, 2, 3
if points[1][1] > points[0][1]:
index_1 = 0
index_4 = 1
else:
index_1 = 1
index_4 = 0
if points[3][1] > points[2][1]:
index_2 = 2
index_3 = 3
else:
index_2 = 3
index_3 = 2
box = [
points[index_1], points[index_2], points[index_3], points[index_4]
]
return box, min(bounding_box[1])
def box_score_fast(self, bitmap, _box):
h, w = bitmap.shape[:2]
box = _box.copy()
xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1)
xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int), 0, w - 1)
ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int), 0, h - 1)
ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int), 0, h - 1)
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
box[:, 0] = box[:, 0] - xmin
box[:, 1] = box[:, 1] - ymin
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
def __call__(self, outs_dict, ratio_list):
pred = outs_dict['maps']
pred = pred[:, 0, :, :]
segmentation = pred > self.thresh
boxes_batch = []
for batch_index in range(pred.shape[0]):
height, width = pred.shape[-2:]
tmp_boxes, tmp_scores = self.boxes_from_bitmap(
pred[batch_index], segmentation[batch_index], width, height)
boxes = []
for k in range(len(tmp_boxes)):
if tmp_scores[k] > self.box_thresh:
boxes.append(tmp_boxes[k])
if len(boxes) > 0:
boxes = np.array(boxes)
ratio_h, ratio_w = ratio_list[batch_index]
boxes[:, :, 0] = boxes[:, :, 0] / ratio_w
boxes[:, :, 1] = boxes[:, :, 1] / ratio_h
boxes_batch.append(boxes)
return boxes_batch
此差异已折叠。
此差异已折叠。
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
此差异已折叠。
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import sys
import paddle.fluid as fluid
import logging
logger = logging.getLogger(__name__)
def check_config_params(config, config_name, params):
for param in params:
if param not in config:
err = "param %s didn't find in %s!" % (param, config_name)
assert False, err
return
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
<paddle.fluid.core_avx.ProgramDesc object at 0x10d15fab0>
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册