未验证 提交 229265e6 编写于 作者: M MissPenguin 提交者: GitHub

Merge pull request #538 from MissPenguin/develop

add sast code
Global:
algorithm: SAST
use_gpu: true
epoch_num: 2000
log_smooth_window: 20
print_batch_step: 2
save_model_dir: ./output/det_sast/
save_epoch_step: 20
eval_batch_step: 5000
train_batch_size_per_card: 8
test_batch_size_per_card: 8
image_shape: [3, 512, 512]
reader_yml: ./configs/det/det_sast_icdar15_reader.yml
pretrain_weights: ./pretrain_models/ResNet50_vd_ssld_pretrained/
save_res_path: ./output/det_sast/predicts_sast.txt
checkpoints:
save_inference_dir:
Architecture:
function: ppocr.modeling.architectures.det_model,DetModel
Backbone:
function: ppocr.modeling.backbones.det_resnet_vd_sast,ResNet
layers: 50
Head:
function: ppocr.modeling.heads.det_sast_head,SASTHead
model_name: large
only_fpn_up: False
# with_cab: False
with_cab: True
Loss:
function: ppocr.modeling.losses.det_sast_loss,SASTLoss
Optimizer:
function: ppocr.optimizer,RMSProp
base_lr: 0.001
decay:
function: piecewise_decay
boundaries: [30000, 50000, 80000, 100000, 150000]
decay_rate: 0.3
PostProcess:
function: ppocr.postprocess.sast_postprocess,SASTPostProcess
score_thresh: 0.5
sample_pts_num: 2
nms_thresh: 0.2
expand_scale: 1.0
shrink_ratio_of_width: 0.3
\ No newline at end of file
Global:
algorithm: SAST
use_gpu: true
epoch_num: 2000
log_smooth_window: 20
print_batch_step: 2
save_model_dir: ./output/det_sast/
save_epoch_step: 20
eval_batch_step: 5000
train_batch_size_per_card: 8
test_batch_size_per_card: 1
image_shape: [3, 512, 512]
reader_yml: ./configs/det/det_sast_totaltext_reader.yml
pretrain_weights: ./pretrain_models/ResNet50_vd_ssld_pretrained/
save_res_path: ./output/det_sast/predicts_sast.txt
checkpoints:
save_inference_dir:
Architecture:
function: ppocr.modeling.architectures.det_model,DetModel
Backbone:
function: ppocr.modeling.backbones.det_resnet_vd_sast,ResNet
layers: 50
Head:
function: ppocr.modeling.heads.det_sast_head,SASTHead
model_name: large
only_fpn_up: False
# with_cab: False
with_cab: True
Loss:
function: ppocr.modeling.losses.det_sast_loss,SASTLoss
Optimizer:
function: ppocr.optimizer,RMSProp
base_lr: 0.001
decay:
function: piecewise_decay
boundaries: [30000, 50000, 80000, 100000, 150000]
decay_rate: 0.3
PostProcess:
function: ppocr.postprocess.sast_postprocess,SASTPostProcess
score_thresh: 0.5
sample_pts_num: 6
nms_thresh: 0.2
expand_scale: 1.2
shrink_ratio_of_width: 0.2
\ No newline at end of file
TrainReader:
reader_function: ppocr.data.det.dataset_traversal,TrainReader
process_function: ppocr.data.det.sast_process,SASTProcessTrain
num_workers: 8
img_set_dir: ./train_data/
label_file_path: [./train_data/icdar13/train_label_json.txt, ./train_data/icdar15/train_label_json.txt, ./train_data/icdar17_mlt_latin/train_label_json.txt, ./train_data/coco_text_icdar_4pts/train_label_json.txt]
data_ratio_list: [0.1, 0.45, 0.3, 0.15]
min_crop_side_ratio: 0.3
min_crop_size: 24
min_text_size: 4
max_text_size: 512
EvalReader:
reader_function: ppocr.data.det.dataset_traversal,EvalTestReader
process_function: ppocr.data.det.sast_process,SASTProcessTest
img_set_dir: ./train_data/icdar2015/text_localization/
label_file_path: ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
max_side_len: 1536
TestReader:
reader_function: ppocr.data.det.dataset_traversal,EvalTestReader
process_function: ppocr.data.det.sast_process,SASTProcessTest
infer_img:
img_set_dir: ./train_data/icdar2015/text_localization/
label_file_path: ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
do_eval: True
TrainReader:
reader_function: ppocr.data.det.dataset_traversal,TrainReader
process_function: ppocr.data.det.sast_process,SASTProcessTrain
num_workers: 8
img_set_dir: ./train_data/
label_file_path: [./train_data/art_latin_icdar_14pt/train_no_tt_test/train_label_json.txt, ./train_data/total_text_icdar_14pt/train/train_label_json.txt]
data_ratio_list: [0.5, 0.5]
min_crop_side_ratio: 0.3
min_crop_size: 24
min_text_size: 4
max_text_size: 512
EvalReader:
reader_function: ppocr.data.det.dataset_traversal,EvalTestReader
process_function: ppocr.data.det.sast_process,SASTProcessTest
img_set_dir: ./train_data/afs/
label_file_path: ./train_data/afs/total_text/test_label_json.txt
max_side_len: 768
TestReader:
reader_function: ppocr.data.det.dataset_traversal,EvalTestReader
process_function: ppocr.data.det.sast_process,SASTProcessTest
infer_img:
max_side_len: 768
......@@ -31,22 +31,27 @@ class TrainReader(object):
def __init__(self, params):
self.num_workers = params['num_workers']
self.label_file_path = params['label_file_path']
print(self.label_file_path)
self.use_mul_data = False
if isinstance(self.label_file_path, list):
self.use_mul_data = True
self.data_ratio_list = params['data_ratio_list']
self.batch_size = params['train_batch_size_per_card']
assert 'process_function' in params,\
"absence process_function in Reader"
self.process = create_module(params['process_function'])(params)
def __call__(self, process_id):
with open(self.label_file_path, "rb") as fin:
label_infor_list = fin.readlines()
img_num = len(label_infor_list)
img_id_list = list(range(img_num))
if sys.platform == "win32" and self.num_workers != 1:
print("multiprocess is not fully compatible with Windows."
"num_workers will be 1.")
self.num_workers = 1
def sample_iter_reader():
with open(self.label_file_path, "rb") as fin:
label_infor_list = fin.readlines()
img_num = len(label_infor_list)
img_id_list = list(range(img_num))
random.shuffle(img_id_list)
if sys.platform == "win32" and self.num_workers != 1:
print("multiprocess is not fully compatible with Windows."
"num_workers will be 1.")
self.num_workers = 1
for img_id in range(process_id, img_num, self.num_workers):
label_infor = label_infor_list[img_id_list[img_id]]
outs = self.process(label_infor)
......@@ -54,13 +59,64 @@ class TrainReader(object):
continue
yield outs
def sample_iter_reader_mul():
batch_size = 1000
data_source_list = self.label_file_path
batch_size_list = list(map(int, [max(1.0, batch_size * x) for x in self.data_ratio_list]))
print(self.data_ratio_list, batch_size_list)
data_filename_list, data_size_list, fetch_record_list = [], [], []
for data_source in data_source_list:
image_files = open(data_source, "rb").readlines()
random.shuffle(image_files)
data_filename_list.append(image_files)
data_size_list.append(len(image_files))
fetch_record_list.append(0)
image_batch, poly_batch = [], []
# get a batch of img_fns and poly_fns
for i in range(0, len(batch_size_list)):
bs = batch_size_list[i]
ds = data_size_list[i]
image_names = data_filename_list[i]
fetch_record = fetch_record_list[i]
data_path = data_source_list[i]
for j in range(fetch_record, fetch_record + bs):
index = j % ds
image_batch.append(image_names[index])
if (fetch_record + bs) > ds:
fetch_record_list[i] = 0
random.shuffle(data_filename_list[i])
else:
fetch_record_list[i] = fetch_record + bs
if sys.platform == "win32":
print("multiprocess is not fully compatible with Windows."
"num_workers will be 1.")
self.num_workers = 1
for label_infor in image_batch:
outs = self.process(label_infor)
if outs is None:
continue
yield outs
def batch_iter_reader():
batch_outs = []
for outs in sample_iter_reader():
batch_outs.append(outs)
if len(batch_outs) == self.batch_size:
yield batch_outs
batch_outs = []
if self.use_mul_data:
print("Sample date from multiple datasets!")
for outs in sample_iter_reader_mul():
batch_outs.append(outs)
if len(batch_outs) == self.batch_size:
yield batch_outs
batch_outs = []
else:
for outs in sample_iter_reader():
batch_outs.append(outs)
if len(batch_outs) == self.batch_size:
yield batch_outs
batch_outs = []
return batch_iter_reader
......
此差异已折叠。
......@@ -97,6 +97,24 @@ class DetModel(object):
'shrink_mask':shrink_mask,\
'threshold_map':threshold_map,\
'threshold_mask':threshold_mask}
elif self.algorithm == "SAST":
input_score = fluid.layers.data(
name='score', shape=[1, 128, 128], dtype='float32')
input_border = fluid.layers.data(
name='border', shape=[5, 128, 128], dtype='float32')
input_mask = fluid.layers.data(
name='mask', shape=[1, 128, 128], dtype='float32')
input_tvo = fluid.layers.data(
# name='tvo', shape=[5, 128, 128], dtype='float32')
name='tvo', shape=[9, 128, 128], dtype='float32')
input_tco = fluid.layers.data(
name='tco', shape=[3, 128, 128], dtype='float32')
feed_list = [image, input_score, input_border, input_mask, input_tvo, input_tco]
labels = {'input_score': input_score,\
'input_border': input_border,\
'input_mask': input_mask,\
'input_tvo': input_tvo,\
'input_tco': input_tco}
loader = fluid.io.DataLoader.from_generator(
feed_list=feed_list,
capacity=64,
......
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
__all__ = ["ResNet"]
class ResNet(object):
def __init__(self, params):
"""
the Resnet backbone network for detection module.
Args:
params(dict): the super parameters for network build
"""
self.layers = params['layers']
supported_layers = [18, 34, 50, 101, 152]
assert self.layers in supported_layers, \
"supported layers are {} but input layer is {}".format(supported_layers, self.layers)
self.is_3x3 = True
def __call__(self, input):
layers = self.layers
is_3x3 = self.is_3x3
# if layers == 18:
# depth = [2, 2, 2, 2]
# elif layers == 34 or layers == 50:
# depth = [3, 4, 6, 3]
# elif layers == 101:
# depth = [3, 4, 23, 3]
# elif layers == 152:
# depth = [3, 8, 36, 3]
# elif layers == 200:
# depth = [3, 12, 48, 3]
# num_filters = [64, 128, 256, 512]
# outs = []
if layers == 18:
depth = [2, 2, 2, 2]#, 3, 3]
elif layers == 34 or layers == 50:
#depth = [3, 4, 6, 3]#, 3, 3]
depth = [3, 4, 6, 3, 3]#, 3]
elif layers == 101:
depth = [3, 4, 23, 3]#, 3, 3]
elif layers == 152:
depth = [3, 8, 36, 3]#, 3, 3]
num_filters = [64, 128, 256, 512, 512]#, 512]
blocks = {}
idx = 'block_0'
blocks[idx] = input
if is_3x3 == False:
conv = self.conv_bn_layer(
input=input,
num_filters=64,
filter_size=7,
stride=2,
act='relu')
else:
conv = self.conv_bn_layer(
input=input,
num_filters=32,
filter_size=3,
stride=2,
act='relu',
name='conv1_1')
conv = self.conv_bn_layer(
input=conv,
num_filters=32,
filter_size=3,
stride=1,
act='relu',
name='conv1_2')
conv = self.conv_bn_layer(
input=conv,
num_filters=64,
filter_size=3,
stride=1,
act='relu',
name='conv1_3')
idx = 'block_1'
blocks[idx] = conv
conv = fluid.layers.pool2d(
input=conv,
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max')
if layers >= 50:
for block in range(len(depth)):
for i in range(depth[block]):
if layers in [101, 152, 200] and block == 2:
if i == 0:
conv_name = "res" + str(block + 2) + "a"
else:
conv_name = "res" + str(block + 2) + "b" + str(i)
else:
conv_name = "res" + str(block + 2) + chr(97 + i)
conv = self.bottleneck_block(
input=conv,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
if_first=block == i == 0,
name=conv_name)
# outs.append(conv)
idx = 'block_' + str(block + 2)
blocks[idx] = conv
else:
for block in range(len(depth)):
for i in range(depth[block]):
conv_name = "res" + str(block + 2) + chr(97 + i)
conv = self.basic_block(
input=conv,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
if_first=block == i == 0,
name=conv_name)
# outs.append(conv)
idx = 'block_' + str(block + 2)
blocks[idx] = conv
# return outs
return blocks
def conv_bn_layer(self,
input,
num_filters,
filter_size,
stride=1,
groups=1,
act=None,
name=None):
conv = fluid.layers.conv2d(
input=input,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
return fluid.layers.batch_norm(
input=conv,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
def conv_bn_layer_new(self,
input,
num_filters,
filter_size,
stride=1,
groups=1,
act=None,
name=None):
pool = fluid.layers.pool2d(
input=input,
pool_size=2,
pool_stride=2,
pool_padding=0,
pool_type='avg',
ceil_mode=True)
conv = fluid.layers.conv2d(
input=pool,
num_filters=num_filters,
filter_size=filter_size,
stride=1,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
return fluid.layers.batch_norm(
input=conv,
act=act,
param_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'),
moving_mean_name=bn_name + '_mean',
moving_variance_name=bn_name + '_variance')
def shortcut(self, input, ch_out, stride, name, if_first=False):
ch_in = input.shape[1]
if ch_in != ch_out or stride != 1:
if if_first:
return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
else:
return self.conv_bn_layer_new(
input, ch_out, 1, stride, name=name)
elif if_first:
return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
else:
return input
def bottleneck_block(self, input, num_filters, stride, name, if_first):
conv0 = self.conv_bn_layer(
input=input,
num_filters=num_filters,
filter_size=1,
act='relu',
name=name + "_branch2a")
conv1 = self.conv_bn_layer(
input=conv0,
num_filters=num_filters,
filter_size=3,
stride=stride,
act='relu',
name=name + "_branch2b")
conv2 = self.conv_bn_layer(
input=conv1,
num_filters=num_filters * 4,
filter_size=1,
act=None,
name=name + "_branch2c")
short = self.shortcut(
input,
num_filters * 4,
stride,
if_first=if_first,
name=name + "_branch1")
return fluid.layers.elementwise_add(x=short, y=conv2, act='relu')
def basic_block(self, input, num_filters, stride, name, if_first):
conv0 = self.conv_bn_layer(
input=input,
num_filters=num_filters,
filter_size=3,
act='relu',
stride=stride,
name=name + "_branch2a")
conv1 = self.conv_bn_layer(
input=conv0,
num_filters=num_filters,
filter_size=3,
act=None,
name=name + "_branch2b")
short = self.shortcut(
input,
num_filters,
stride,
if_first=if_first,
name=name + "_branch1")
return fluid.layers.elementwise_add(x=short, y=conv1, act='relu')
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
from ..common_functions import conv_bn_layer, deconv_bn_layer
from collections import OrderedDict
class SASTHead(object):
"""
SAST:
see arxiv: https://
args:
params(dict): the super parameters for network build
"""
def __init__(self, params):
self.model_name = params['model_name']
self.with_cab = params['with_cab']
def FPN_Up_Fusion(self, blocks):
"""
blocks{}: contain block_2, block_3, block_4, block_5, block_6, block_7 with
1/4, 1/8, 1/16, 1/32, 1/64, 1/128 resolution.
"""
f = [blocks['block_6'], blocks['block_5'], blocks['block_4'], blocks['block_3'], blocks['block_2']]
num_outputs = [256, 256, 192, 192, 128]
g = [None, None, None, None, None]
h = [None, None, None, None, None]
for i in range(5):
h[i] = conv_bn_layer(input=f[i], num_filters=num_outputs[i],
filter_size=1, stride=1, act=None, name='fpn_up_h'+str(i))
for i in range(4):
if i == 0:
g[i] = deconv_bn_layer(input=h[i], num_filters=num_outputs[i + 1], act=None, name='fpn_up_g0')
print("g[{}] shape: {}".format(i, g[i].shape))
else:
g[i] = fluid.layers.elementwise_add(x=g[i - 1], y=h[i])
g[i] = fluid.layers.relu(g[i])
#g[i] = conv_bn_layer(input=g[i], num_filters=num_outputs[i],
# filter_size=1, stride=1, act='relu')
g[i] = conv_bn_layer(input=g[i], num_filters=num_outputs[i],
filter_size=3, stride=1, act='relu', name='fpn_up_g%d_1'%i)
g[i] = deconv_bn_layer(input=g[i], num_filters=num_outputs[i + 1], act=None, name='fpn_up_g%d_2'%i)
print("g[{}] shape: {}".format(i, g[i].shape))
g[4] = fluid.layers.elementwise_add(x=g[3], y=h[4])
g[4] = fluid.layers.relu(g[4])
g[4] = conv_bn_layer(input=g[4], num_filters=num_outputs[4],
filter_size=3, stride=1, act='relu', name='fpn_up_fusion_1')
g[4] = conv_bn_layer(input=g[4], num_filters=num_outputs[4],
filter_size=1, stride=1, act=None, name='fpn_up_fusion_2')
return g[4]
def FPN_Down_Fusion(self, blocks):
"""
blocks{}: contain block_2, block_3, block_4, block_5, block_6, block_7 with
1/4, 1/8, 1/16, 1/32, 1/64, 1/128 resolution.
"""
f = [blocks['block_0'], blocks['block_1'], blocks['block_2']]
num_outputs = [32, 64, 128]
g = [None, None, None]
h = [None, None, None]
for i in range(3):
h[i] = conv_bn_layer(input=f[i], num_filters=num_outputs[i],
filter_size=3, stride=1, act=None, name='fpn_down_h'+str(i))
for i in range(2):
if i == 0:
g[i] = conv_bn_layer(input=h[i], num_filters=num_outputs[i+1], filter_size=3, stride=2, act=None, name='fpn_down_g0')
else:
g[i] = fluid.layers.elementwise_add(x=g[i - 1], y=h[i])
g[i] = fluid.layers.relu(g[i])
g[i] = conv_bn_layer(input=g[i], num_filters=num_outputs[i], filter_size=3, stride=1, act='relu', name='fpn_down_g%d_1'%i)
g[i] = conv_bn_layer(input=g[i], num_filters=num_outputs[i+1], filter_size=3, stride=2, act=None, name='fpn_down_g%d_2'%i)
print("g[{}] shape: {}".format(i, g[i].shape))
g[2] = fluid.layers.elementwise_add(x=g[1], y=h[2])
g[2] = fluid.layers.relu(g[2])
g[2] = conv_bn_layer(input=g[2], num_filters=num_outputs[2],
filter_size=3, stride=1, act='relu', name='fpn_down_fusion_1')
g[2] = conv_bn_layer(input=g[2], num_filters=num_outputs[2],
filter_size=1, stride=1, act=None, name='fpn_down_fusion_2')
return g[2]
def SAST_Header1(self, f_common):
"""Detector header."""
#f_score
f_score = conv_bn_layer(input=f_common, num_filters=64, filter_size=1, stride=1, act='relu', name='f_score1')
f_score = conv_bn_layer(input=f_score, num_filters=64, filter_size=3, stride=1, act='relu', name='f_score2')
f_score = conv_bn_layer(input=f_score, num_filters=128, filter_size=1, stride=1, act='relu', name='f_score3')
f_score = conv_bn_layer(input=f_score, num_filters=1, filter_size=3, stride=1, name='f_score4')
f_score = fluid.layers.sigmoid(f_score)
print("f_score shape: {}".format(f_score.shape))
#f_boder
f_border = conv_bn_layer(input=f_common, num_filters=64, filter_size=1, stride=1, act='relu', name='f_border1')
f_border = conv_bn_layer(input=f_border, num_filters=64, filter_size=3, stride=1, act='relu', name='f_border2')
f_border = conv_bn_layer(input=f_border, num_filters=128, filter_size=1, stride=1, act='relu', name='f_border3')
f_border = conv_bn_layer(input=f_border, num_filters=4, filter_size=3, stride=1, name='f_border4')
print("f_border shape: {}".format(f_border.shape))
return f_score, f_border
def SAST_Header2(self, f_common):
"""Detector header."""
#f_tvo
f_tvo = conv_bn_layer(input=f_common, num_filters=64, filter_size=1, stride=1, act='relu', name='f_tvo1')
f_tvo = conv_bn_layer(input=f_tvo, num_filters=64, filter_size=3, stride=1, act='relu', name='f_tvo2')
f_tvo = conv_bn_layer(input=f_tvo, num_filters=128, filter_size=1, stride=1, act='relu', name='f_tvo3')
f_tvo = conv_bn_layer(input=f_tvo, num_filters=8, filter_size=3, stride=1, name='f_tvo4')
print("f_tvo shape: {}".format(f_tvo.shape))
#f_tco
f_tco = conv_bn_layer(input=f_common, num_filters=64, filter_size=1, stride=1, act='relu', name='f_tco1')
f_tco = conv_bn_layer(input=f_tco, num_filters=64, filter_size=3, stride=1, act='relu', name='f_tco2')
f_tco = conv_bn_layer(input=f_tco, num_filters=128, filter_size=1, stride=1, act='relu', name='f_tco3')
f_tco = conv_bn_layer(input=f_tco, num_filters=2, filter_size=3, stride=1, name='f_tco4')
print("f_tco shape: {}".format(f_tco.shape))
return f_tvo, f_tco
def cross_attention(self, f_common):
"""
"""
f_shape = fluid.layers.shape(f_common)
f_theta = conv_bn_layer(input=f_common, num_filters=128, filter_size=1, stride=1, act='relu', name='f_theta')
f_phi = conv_bn_layer(input=f_common, num_filters=128, filter_size=1, stride=1, act='relu', name='f_phi')
f_g = conv_bn_layer(input=f_common, num_filters=128, filter_size=1, stride=1, act='relu', name='f_g')
### horizon
fh_theta = f_theta
fh_phi = f_phi
fh_g = f_g
#flatten
fh_theta = fluid.layers.transpose(fh_theta, [0, 2, 3, 1])
fh_theta = fluid.layers.reshape(fh_theta, [f_shape[0] * f_shape[2], f_shape[3], 128])
fh_phi = fluid.layers.transpose(fh_phi, [0, 2, 3, 1])
fh_phi = fluid.layers.reshape(fh_phi, [f_shape[0] * f_shape[2], f_shape[3], 128])
fh_g = fluid.layers.transpose(fh_g, [0, 2, 3, 1])
fh_g = fluid.layers.reshape(fh_g, [f_shape[0] * f_shape[2], f_shape[3], 128])
#correlation
fh_attn = fluid.layers.matmul(fh_theta, fluid.layers.transpose(fh_phi, [0, 2, 1]))
#scale
fh_attn = fh_attn / (128 ** 0.5)
fh_attn = fluid.layers.softmax(fh_attn)
#weighted sum
fh_weight = fluid.layers.matmul(fh_attn, fh_g)
fh_weight = fluid.layers.reshape(fh_weight, [f_shape[0], f_shape[2], f_shape[3], 128])
print("fh_weight: {}".format(fh_weight.shape))
fh_weight = fluid.layers.transpose(fh_weight, [0, 3, 1, 2])
fh_weight = conv_bn_layer(input=fh_weight, num_filters=128, filter_size=1, stride=1, name='fh_weight')
#short cut
fh_sc = conv_bn_layer(input=f_common, num_filters=128, filter_size=1, stride=1, name='fh_sc')
f_h = fluid.layers.relu(fh_weight + fh_sc)
######
#vertical
fv_theta = fluid.layers.transpose(f_theta, [0, 1, 3, 2])
fv_phi = fluid.layers.transpose(f_phi, [0, 1, 3, 2])
fv_g = fluid.layers.transpose(f_g, [0, 1, 3, 2])
#flatten
fv_theta = fluid.layers.transpose(fv_theta, [0, 2, 3, 1])
fv_theta = fluid.layers.reshape(fv_theta, [f_shape[0] * f_shape[3], f_shape[2], 128])
fv_phi = fluid.layers.transpose(fv_phi, [0, 2, 3, 1])
fv_phi = fluid.layers.reshape(fv_phi, [f_shape[0] * f_shape[3], f_shape[2], 128])
fv_g = fluid.layers.transpose(fv_g, [0, 2, 3, 1])
fv_g = fluid.layers.reshape(fv_g, [f_shape[0] * f_shape[3], f_shape[2], 128])
#correlation
fv_attn = fluid.layers.matmul(fv_theta, fluid.layers.transpose(fv_phi, [0, 2, 1]))
#scale
fv_attn = fv_attn / (128 ** 0.5)
fv_attn = fluid.layers.softmax(fv_attn)
#weighted sum
fv_weight = fluid.layers.matmul(fv_attn, fv_g)
fv_weight = fluid.layers.reshape(fv_weight, [f_shape[0], f_shape[3], f_shape[2], 128])
print("fv_weight: {}".format(fv_weight.shape))
fv_weight = fluid.layers.transpose(fv_weight, [0, 3, 2, 1])
fv_weight = conv_bn_layer(input=fv_weight, num_filters=128, filter_size=1, stride=1, name='fv_weight')
#short cut
fv_sc = conv_bn_layer(input=f_common, num_filters=128, filter_size=1, stride=1, name='fv_sc')
f_v = fluid.layers.relu(fv_weight + fv_sc)
######
f_attn = fluid.layers.concat([f_h, f_v], axis=1)
f_attn = conv_bn_layer(input=f_attn, num_filters=128, filter_size=1, stride=1, act='relu', name='f_attn')
return f_attn
def __call__(self, blocks, with_cab=False):
for k, v in blocks.items():
print(k, v.shape)
#down fpn
f_down = self.FPN_Down_Fusion(blocks)
print("f_down shape: {}".format(f_down.shape))
#up fpn
f_up = self.FPN_Up_Fusion(blocks)
print("f_up shape: {}".format(f_up.shape))
#fusion
f_common = fluid.layers.elementwise_add(x=f_down, y=f_up)
f_common = fluid.layers.relu(f_common)
print("f_common: {}".format(f_common.shape))
if self.with_cab:
print('enhence f_common with CAB.')
f_common = self.cross_attention(f_common)
f_score, f_border= self.SAST_Header1(f_common)
f_tvo, f_tco = self.SAST_Header2(f_common)
predicts = OrderedDict()
predicts['f_score'] = f_score
predicts['f_border'] = f_border
predicts['f_tvo'] = f_tvo
predicts['f_tco'] = f_tco
return predicts
\ No newline at end of file
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
class SASTLoss(object):
"""
SAST Loss function
"""
def __init__(self, params=None):
super(SASTLoss, self).__init__()
def __call__(self, predicts, labels):
"""
tcl_pos: N x 128 x 3
tcl_mask: N x 128 x 1
tcl_label: N x X list or LoDTensor
"""
f_score = predicts['f_score']
f_border = predicts['f_border']
f_tvo = predicts['f_tvo']
f_tco = predicts['f_tco']
l_score = labels['input_score']
l_border = labels['input_border']
l_mask = labels['input_mask']
l_tvo = labels['input_tvo']
l_tco = labels['input_tco']
#score_loss
intersection = fluid.layers.reduce_sum(f_score * l_score * l_mask)
union = fluid.layers.reduce_sum(f_score * l_mask) + fluid.layers.reduce_sum(l_score * l_mask)
score_loss = 1.0 - 2 * intersection / (union + 1e-5)
#border loss
l_border_split, l_border_norm = fluid.layers.split(l_border, num_or_sections=[4, 1], dim=1)
f_border_split = f_border
l_border_norm_split = fluid.layers.expand(x=l_border_norm, expand_times=[1, 4, 1, 1])
l_border_score = fluid.layers.expand(x=l_score, expand_times=[1, 4, 1, 1])
l_border_mask = fluid.layers.expand(x=l_mask, expand_times=[1, 4, 1, 1])
border_diff = l_border_split - f_border_split
abs_border_diff = fluid.layers.abs(border_diff)
border_sign = abs_border_diff < 1.0
border_sign = fluid.layers.cast(border_sign, dtype='float32')
border_sign.stop_gradient = True
border_in_loss = 0.5 * abs_border_diff * abs_border_diff * border_sign + \
(abs_border_diff - 0.5) * (1.0 - border_sign)
border_out_loss = l_border_norm_split * border_in_loss
border_loss = fluid.layers.reduce_sum(border_out_loss * l_border_score * l_border_mask) / \
(fluid.layers.reduce_sum(l_border_score * l_border_mask) + 1e-5)
#tvo_loss
l_tvo_split, l_tvo_norm = fluid.layers.split(l_tvo, num_or_sections=[8, 1], dim=1)
f_tvo_split = f_tvo
l_tvo_norm_split = fluid.layers.expand(x=l_tvo_norm, expand_times=[1, 8, 1, 1])
l_tvo_score = fluid.layers.expand(x=l_score, expand_times=[1, 8, 1, 1])
l_tvo_mask = fluid.layers.expand(x=l_mask, expand_times=[1, 8, 1, 1])
#
tvo_geo_diff = l_tvo_split - f_tvo_split
abs_tvo_geo_diff = fluid.layers.abs(tvo_geo_diff)
tvo_sign = abs_tvo_geo_diff < 1.0
tvo_sign = fluid.layers.cast(tvo_sign, dtype='float32')
tvo_sign.stop_gradient = True
tvo_in_loss = 0.5 * abs_tvo_geo_diff * abs_tvo_geo_diff * tvo_sign + \
(abs_tvo_geo_diff - 0.5) * (1.0 - tvo_sign)
tvo_out_loss = l_tvo_norm_split * tvo_in_loss
tvo_loss = fluid.layers.reduce_sum(tvo_out_loss * l_tvo_score * l_tvo_mask) / \
(fluid.layers.reduce_sum(l_tvo_score * l_tvo_mask) + 1e-5)
#tco_loss
l_tco_split, l_tco_norm = fluid.layers.split(l_tco, num_or_sections=[2, 1], dim=1)
f_tco_split = f_tco
l_tco_norm_split = fluid.layers.expand(x=l_tco_norm, expand_times=[1, 2, 1, 1])
l_tco_score = fluid.layers.expand(x=l_score, expand_times=[1, 2, 1, 1])
l_tco_mask = fluid.layers.expand(x=l_mask, expand_times=[1, 2, 1, 1])
#
tco_geo_diff = l_tco_split - f_tco_split
abs_tco_geo_diff = fluid.layers.abs(tco_geo_diff)
tco_sign = abs_tco_geo_diff < 1.0
tco_sign = fluid.layers.cast(tco_sign, dtype='float32')
tco_sign.stop_gradient = True
tco_in_loss = 0.5 * abs_tco_geo_diff * abs_tco_geo_diff * tco_sign + \
(abs_tco_geo_diff - 0.5) * (1.0 - tco_sign)
tco_out_loss = l_tco_norm_split * tco_in_loss
tco_loss = fluid.layers.reduce_sum(tco_out_loss * l_tco_score * l_tco_mask) / \
(fluid.layers.reduce_sum(l_tco_score * l_tco_mask) + 1e-5)
# total loss
tvo_lw, tco_lw = 1.5, 1.5
score_lw, border_lw = 1.0, 1.0
total_loss = score_loss * score_lw + border_loss * border_lw + \
tvo_loss * tvo_lw + tco_loss * tco_lw
losses = {'total_loss':total_loss, "score_loss":score_loss,\
"border_loss":border_loss, 'tvo_loss':tvo_loss, 'tco_loss':tco_loss}
return losses
\ No newline at end of file
......@@ -65,3 +65,44 @@ def AdamDecay(params, parameter_list=None):
regularization=L2Decay(regularization_coeff=l2_decay),
parameter_list=parameter_list)
return optimizer
def RMSProp(params, parameter_list=None):
"""
define optimizer function
args:
params(dict): the super parameters
parameter_list (list): list of Variable names to update to minimize loss
return:
"""
base_lr = params.get("base_lr", 0.001)
l2_decay = params.get("l2_decay", 0.00005)
if 'decay' in params:
supported_decay_mode = ["cosine_decay", "piecewise_decay"]
params = params['decay']
decay_mode = params['function']
assert decay_mode in supported_decay_mode, "Supported decay mode is {}, but got {}".format(
supported_decay_mode, decay_mode)
if decay_mode == "cosine_decay":
step_each_epoch = params['step_each_epoch']
total_epoch = params['total_epoch']
base_lr = fluid.layers.cosine_decay(
learning_rate=base_lr,
step_each_epoch=step_each_epoch,
epochs=total_epoch)
elif decay_mode == "piecewise_decay":
boundaries = params["boundaries"]
decay_rate = params["decay_rate"]
values = [
base_lr * decay_rate**idx
for idx in range(len(boundaries) + 1)
]
base_lr = fluid.layers.piecewise_decay(boundaries, values)
optimizer = fluid.optimizer.RMSProp(
learning_rate=base_lr,
regularization=fluid.regularizer.L2Decay(regularization_coeff=l2_decay))
return optimizer
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..'))
import numpy as np
from .locality_aware_nms import nms_locality
# import lanms
import cv2
import time
class SASTPostProcess(object):
"""
The post process for SAST.
"""
def __init__(self, params):
self.score_thresh = params.get('score_thresh', 0.5)
self.nms_thresh = params.get('nms_thresh', 0.2)
self.sample_pts_num = params.get('sample_pts_num', 2)
self.shrink_ratio_of_width = params.get('shrink_ratio_of_width', 0.3)
self.expand_scale = params.get('expand_scale', 1.0)
self.tcl_map_thresh = 0.5
# c++ la-nms is faster, but only support python 3.5
self.is_python35 = False
if sys.version_info.major == 3 and sys.version_info.minor == 5:
self.is_python35 = True
def point_pair2poly(self, point_pair_list):
"""
Transfer vertical point_pairs into poly point in clockwise.
"""
# constract poly
point_num = len(point_pair_list) * 2
point_list = [0] * point_num
for idx, point_pair in enumerate(point_pair_list):
point_list[idx] = point_pair[0]
point_list[point_num - 1 - idx] = point_pair[1]
return np.array(point_list).reshape(-1, 2)
def shrink_quad_along_width(self, quad, begin_width_ratio=0., end_width_ratio=1.):
"""
Generate shrink_quad_along_width.
"""
ratio_pair = np.array([[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
def expand_poly_along_width(self, poly, shrink_ratio_of_width=0.3):
"""
expand poly along width.
"""
point_num = poly.shape[0]
left_quad = np.array([poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
(np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
left_quad_expand = self.shrink_quad_along_width(left_quad, left_ratio, 1.0)
right_quad = np.array([poly[point_num // 2 - 2], poly[point_num // 2 - 1],
poly[point_num // 2], poly[point_num // 2 + 1]], dtype=np.float32)
right_ratio = 1.0 + \
shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
(np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
right_quad_expand = self.shrink_quad_along_width(right_quad, 0.0, right_ratio)
poly[0] = left_quad_expand[0]
poly[-1] = left_quad_expand[-1]
poly[point_num // 2 - 1] = right_quad_expand[1]
poly[point_num // 2] = right_quad_expand[2]
return poly
def restore_quad(self, tcl_map, tcl_map_thresh, tvo_map):
"""Restore quad."""
xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
xy_text = xy_text[:, ::-1] # (n, 2)
# Sort the text boxes via the y axis
xy_text = xy_text[np.argsort(xy_text[:, 1])]
scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0]
scores = scores[:, np.newaxis]
# Restore
point_num = int(tvo_map.shape[-1] / 2)
assert point_num == 4
tvo_map = tvo_map[xy_text[:, 1], xy_text[:, 0], :]
xy_text_tile = np.tile(xy_text, (1, point_num)) # (n, point_num * 2)
quads = xy_text_tile - tvo_map
return scores, quads, xy_text
def quad_area(self, quad):
"""
compute area of a quad.
"""
edge = [
(quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]),
(quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]),
(quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]),
(quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1])
]
return np.sum(edge) / 2.
def nms(self, dets):
if self.is_python35:
import lanms
dets = lanms.merge_quadrangle_n9(dets, self.nms_thresh)
else:
dets = nms_locality(dets, self.nms_thresh)
return dets
def cluster_by_quads_tco(self, tcl_map, tcl_map_thresh, quads, tco_map):
"""
Cluster pixels in tcl_map based on quads.
"""
instance_count = quads.shape[0] + 1 # contain background
instance_label_map = np.zeros(tcl_map.shape[:2], dtype=np.int32)
if instance_count == 1:
return instance_count, instance_label_map
# predict text center
xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
n = xy_text.shape[0]
xy_text = xy_text[:, ::-1] # (n, 2)
tco = tco_map[xy_text[:, 1], xy_text[:, 0], :] # (n, 2)
pred_tc = xy_text - tco
# get gt text center
m = quads.shape[0]
gt_tc = np.mean(quads, axis=1) # (m, 2)
pred_tc_tile = np.tile(pred_tc[:, np.newaxis, :], (1, m, 1)) # (n, m, 2)
gt_tc_tile = np.tile(gt_tc[np.newaxis, :, :], (n, 1, 1)) # (n, m, 2)
dist_mat = np.linalg.norm(pred_tc_tile - gt_tc_tile, axis=2) # (n, m)
xy_text_assign = np.argmin(dist_mat, axis=1) + 1 # (n,)
instance_label_map[xy_text[:, 1], xy_text[:, 0]] = xy_text_assign
return instance_count, instance_label_map
def estimate_sample_pts_num(self, quad, xy_text):
"""
Estimate sample points number.
"""
eh = (np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2])) / 2.0
ew = (np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3])) / 2.0
dense_sample_pts_num = max(2, int(ew))
dense_xy_center_line = xy_text[np.linspace(0, xy_text.shape[0] - 1, dense_sample_pts_num,
endpoint=True, dtype=np.float32).astype(np.int32)]
dense_xy_center_line_diff = dense_xy_center_line[1:] - dense_xy_center_line[:-1]
estimate_arc_len = np.sum(np.linalg.norm(dense_xy_center_line_diff, axis=1))
sample_pts_num = max(2, int(estimate_arc_len / eh))
return sample_pts_num
def detect_sast(self, tcl_map, tvo_map, tbo_map, tco_map, ratio_w, ratio_h, src_w, src_h,
shrink_ratio_of_width=0.3, tcl_map_thresh=0.5, offset_expand=1.0, out_strid=4.0):
"""
first resize the tcl_map, tvo_map and tbo_map to the input_size, then restore the polys
"""
# restore quad
scores, quads, xy_text = self.restore_quad(tcl_map, tcl_map_thresh, tvo_map)
dets = np.hstack((quads, scores)).astype(np.float32, copy=False)
dets = self.nms(dets)
if dets.shape[0] == 0:
return []
quads = dets[:, :-1].reshape(-1, 4, 2)
# Compute quad area
quad_areas = []
for quad in quads:
quad_areas.append(-self.quad_area(quad))
# instance segmentation
# instance_count, instance_label_map = cv2.connectedComponents(tcl_map.astype(np.uint8), connectivity=8)
instance_count, instance_label_map = self.cluster_by_quads_tco(tcl_map, tcl_map_thresh, quads, tco_map)
# restore single poly with tcl instance.
poly_list = []
for instance_idx in range(1, instance_count):
xy_text = np.argwhere(instance_label_map == instance_idx)[:, ::-1]
quad = quads[instance_idx - 1]
q_area = quad_areas[instance_idx - 1]
if q_area < 5:
continue
#
len1 = float(np.linalg.norm(quad[0] -quad[1]))
len2 = float(np.linalg.norm(quad[1] -quad[2]))
min_len = min(len1, len2)
if min_len < 3:
continue
# filter small CC
if xy_text.shape[0] <= 0:
continue
# filter low confidence instance
xy_text_scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0]
if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.1:
# if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.05:
continue
# sort xy_text
left_center_pt = np.array([[(quad[0, 0] + quad[-1, 0]) / 2.0,
(quad[0, 1] + quad[-1, 1]) / 2.0]]) # (1, 2)
right_center_pt = np.array([[(quad[1, 0] + quad[2, 0]) / 2.0,
(quad[1, 1] + quad[2, 1]) / 2.0]]) # (1, 2)
proj_unit_vec = (right_center_pt - left_center_pt) / \
(np.linalg.norm(right_center_pt - left_center_pt) + 1e-6)
proj_value = np.sum(xy_text * proj_unit_vec, axis=1)
xy_text = xy_text[np.argsort(proj_value)]
# Sample pts in tcl map
if self.sample_pts_num == 0:
sample_pts_num = self.estimate_sample_pts_num(quad, xy_text)
else:
sample_pts_num = self.sample_pts_num
xy_center_line = xy_text[np.linspace(0, xy_text.shape[0] - 1, sample_pts_num,
endpoint=True, dtype=np.float32).astype(np.int32)]
point_pair_list = []
for x, y in xy_center_line:
# get corresponding offset
offset = tbo_map[y, x, :].reshape(2, 2)
if offset_expand != 1.0:
offset_length = np.linalg.norm(offset, axis=1, keepdims=True)
expand_length = np.clip(offset_length * (offset_expand - 1), a_min=0.5, a_max=3.0)
offset_detal = offset / offset_length * expand_length
offset = offset + offset_detal
# original point
ori_yx = np.array([y, x], dtype=np.float32)
point_pair = (ori_yx + offset)[:, ::-1]* out_strid / np.array([ratio_w, ratio_h]).reshape(-1, 2)
point_pair_list.append(point_pair)
# ndarry: (x, 2), expand poly along width
detected_poly = self.point_pair2poly(point_pair_list)
detected_poly = self.expand_poly_along_width(detected_poly, shrink_ratio_of_width)
detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w)
detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h)
poly_list.append(detected_poly)
return poly_list
def __call__(self, outs_dict, ratio_list):
score_list = outs_dict['f_score']
border_list = outs_dict['f_border']
tvo_list = outs_dict['f_tvo']
tco_list = outs_dict['f_tco']
img_num = len(ratio_list)
poly_lists = []
for ino in range(img_num):
p_score = score_list[ino].transpose((1,2,0))
p_border = border_list[ino].transpose((1,2,0))
p_tvo = tvo_list[ino].transpose((1,2,0))
p_tco = tco_list[ino].transpose((1,2,0))
# print(p_score.shape, p_border.shape, p_tvo.shape, p_tco.shape)
ratio_h, ratio_w, src_h, src_w = ratio_list[ino]
poly_list = self.detect_sast(p_score, p_tvo, p_border, p_tco, ratio_w, ratio_h, src_w, src_h,
shrink_ratio_of_width=self.shrink_ratio_of_width,
tcl_map_thresh=self.tcl_map_thresh, offset_expand=self.expand_scale)
poly_lists.append(poly_list)
return poly_lists
文件模式从 100755 更改为 100644
......@@ -88,8 +88,8 @@ class DetectionIoUEvaluator(object):
points = gt[n]['points']
# transcription = gt[n]['text']
dontCare = gt[n]['ignore']
points = Polygon(points)
points = points.buffer(0)
# points = Polygon(points)
# points = points.buffer(0)
if not Polygon(points).is_valid or not Polygon(points).is_simple:
continue
......@@ -105,8 +105,8 @@ class DetectionIoUEvaluator(object):
for n in range(len(pred)):
points = pred[n]['points']
points = Polygon(points)
points = points.buffer(0)
# points = Polygon(points)
# points = points.buffer(0)
if not Polygon(points).is_valid or not Polygon(points).is_simple:
continue
......
......@@ -82,10 +82,8 @@ default_config = {'Global': {'debug': False, }}
def load_config(file_path):
"""
Load config from yml/yaml file.
Args:
file_path (str): Path of the config file to be loaded.
Returns: global config
"""
merge_config(default_config)
......@@ -104,10 +102,8 @@ def load_config(file_path):
def merge_config(config):
"""
Merge config into global config.
Args:
config (dict): Config to be merged.
Returns: global config
"""
for key, value in config.items():
......@@ -158,13 +154,11 @@ def build(config, main_prog, startup_prog, mode):
3. create a model
4. create fetchs
5. create an optimizer
Args:
config(dict): config
main_prog(): main program
startup_prog(): startup program
is_train(bool): train or valid
Returns:
dataloader(): a bridge between the model and the data
fetchs(dict): dict of model outputs(included loss and measures)
......@@ -415,7 +409,7 @@ def preprocess():
check_gpu(use_gpu)
alg = config['Global']['algorithm']
assert alg in ['EAST', 'DB', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']
assert alg in ['EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']
if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']:
config['Global']['char_ops'] = CharacterOps(config['Global'])
......@@ -423,7 +417,7 @@ def preprocess():
startup_program = fluid.Program()
train_program = fluid.Program()
if alg in ['EAST', 'DB']:
if alg in ['EAST', 'DB', 'SAST']:
train_alg_type = 'det'
else:
train_alg_type = 'rec'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册