Commit 7eeef593 authored by T tink2123

update multi dic and export

Parent a948584c
 Global:
-  use_gpu: true
+  use_gpu: True
   epoch_num: 500
   log_smooth_window: 20
   print_batch_step: 10
@@ -15,7 +15,7 @@ Global:
   use_visualdl: False
   infer_img:
   # for data or label process
-  character_dict_path: ppocr/utils/dict/ic15_dict.txt
+  character_dict_path: ppocr/utils/dict/en_dict.txt
   character_type: ch
   max_text_length: 25
   infer_mode: False
...
 Global:
-  use_gpu: true
+  use_gpu: True
   epoch_num: 500
   log_smooth_window: 20
   print_batch_step: 10
@@ -9,9 +9,9 @@ Global:
   eval_batch_step: [0, 2000]
   # if pretrained_model is saved in static mode, load_static_weights must set to True
   cal_metric_during_train: True
   pretrained_model:
   checkpoints:
   save_inference_dir:
   use_visualdl: False
   infer_img:
   # for data or label process
@@ -19,7 +19,7 @@ Global:
   character_type: french
   max_text_length: 25
   infer_mode: False
-  use_space_char: True
+  use_space_char: False
 Optimizer:
...
 Global:
-  use_gpu: true
+  use_gpu: True
   epoch_num: 500
   log_smooth_window: 20
   print_batch_step: 10
@@ -19,7 +19,7 @@ Global:
   character_type: german
   max_text_length: 25
   infer_mode: False
-  use_space_char: True
+  use_space_char: False
 Optimizer:
...
 Global:
-  use_gpu: true
+  use_gpu: True
   epoch_num: 500
   log_smooth_window: 20
   print_batch_step: 10
@@ -19,7 +19,7 @@ Global:
   character_type: japan
   max_text_length: 25
   infer_mode: False
-  use_space_char: True
+  use_space_char: False
 Optimizer:
...
 Global:
-  use_gpu: true
+  use_gpu: True
   epoch_num: 500
   log_smooth_window: 20
   print_batch_step: 10
@@ -19,7 +19,7 @@ Global:
   character_type: korean
   max_text_length: 25
   infer_mode: False
-  use_space_char: True
+  use_space_char: False
 Optimizer:
...
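Note: the recognition configs above point the English model at ppocr/utils/dict/en_dict.txt and set use_space_char to False for the multilingual models. For illustration, a minimal sketch of how a character dictionary file like this is typically turned into a recognition label list (the helper name and exact stripping rules are assumptions, not code from this commit):

# Hypothetical helper: read one character per line from a dict file and
# optionally append a space label, mirroring the use_space_char switch.
def load_char_dict(character_dict_path, use_space_char=False):
    characters = []
    with open(character_dict_path, "rb") as fin:
        for line in fin:
            characters.append(line.decode("utf-8").strip("\n").strip("\r\n"))
    if use_space_char and " " not in characters:
        characters.append(" ")
    return characters

# e.g. labels = load_char_dict("ppocr/utils/dict/en_dict.txt", use_space_char=False)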
@@ -81,7 +81,7 @@ cv::Mat Classifier::Run(cv::Mat &img) {
 void Classifier::LoadModel(const std::string &model_dir) {
   AnalysisConfig config;
-  config.SetModel(model_dir + "/model", model_dir + "/params");
+  config.SetModel(model_dir + ".pdmodel", model_dir + ".pdiparams");
   if (this->use_gpu_) {
     config.EnableUseGpu(this->gpu_mem_, this->gpu_id_);
...
@@ -18,7 +18,7 @@ namespace PaddleOCR {
 void DBDetector::LoadModel(const std::string &model_dir) {
   AnalysisConfig config;
-  config.SetModel(model_dir + "/model", model_dir + "/params");
+  config.SetModel(model_dir + ".pdmodel", model_dir + ".pdiparams");
   if (this->use_gpu_) {
     config.EnableUseGpu(this->gpu_mem_, this->gpu_id_);
...
@@ -103,7 +103,7 @@ void CRNNRecognizer::Run(std::vector<std::vector<std::vector<int>>> boxes,
 void CRNNRecognizer::LoadModel(const std::string &model_dir) {
   AnalysisConfig config;
-  config.SetModel(model_dir + "/model", model_dir + "/params");
+  config.SetModel(model_dir + ".pdmodel", model_dir + ".pdiparams");
   if (this->use_gpu_) {
     config.EnableUseGpu(this->gpu_mem_, this->gpu_id_);
@@ -186,4 +186,4 @@ cv::Mat CRNNRecognizer::GetRotateCropImage(const cv::Mat &srcimage,
   }
 }
 } // namespace PaddleOCR
\ No newline at end of file
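Note: the C++ classifier, detector, and recognizer now load the combined inference format, so model_dir is treated as a file-name prefix (<prefix>.pdmodel plus <prefix>.pdiparams) instead of a directory holding model/params. For reference, a rough Python equivalent with the Paddle 2.x inference API (the model prefix below is a placeholder):

# Sketch: create a predictor using the same file-naming convention the C++
# code now expects. Assumes Paddle 2.x (paddle.inference).
from paddle.inference import Config, create_predictor

model_prefix = "./inference/rec_crnn/inference"  # placeholder prefix
config = Config(model_prefix + ".pdmodel", model_prefix + ".pdiparams")
config.disable_gpu()  # or config.enable_use_gpu(500, 0) to mirror EnableUseGpu
predictor = create_predictor(config)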
0
1
2
3
4
5
6
7
8
9
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
@@ -132,4 +132,5 @@ j
 ³
 Å
 $
 #
\ No newline at end of file
@@ -123,4 +123,5 @@ z
 â
 å
 æ
 é
\ No newline at end of file
@@ -4395,4 +4395,5 @@ z
\ No newline at end of file
@@ -179,7 +179,7 @@ z
 с
 т
 я
@@ -3684,4 +3684,5 @@ z
\ No newline at end of file
@@ -39,26 +39,12 @@ def parse_args():
     return parser.parse_args()

-class Model(paddle.nn.Layer):
-    def __init__(self, model):
-        super(Model, self).__init__()
-        self.pre_model = model
-
-    # Please modify the 'shape' according to actual needs
-    @to_static(input_spec=[
-        paddle.static.InputSpec(
-            shape=[None, 3, 640, 640], dtype='float32')
-    ])
-    def forward(self, inputs):
-        x = self.pre_model(inputs)
-        return x

 def main():
     FLAGS = parse_args()
     config = load_config(FLAGS.config)
     logger = get_logger()
     # build post process
     post_process_class = build_post_process(config['PostProcess'],
                                             config['Global'])
@@ -71,9 +57,16 @@ def main():
     init_model(config, model, logger)
     model.eval()

-    model = Model(model)
-    save_path = '{}/{}'.format(FLAGS.output_path,
-                               config['Architecture']['model_type'])
+    save_path = '{}/{}/inference'.format(FLAGS.output_path,
+                                         config['Architecture']['model_type'])
+    infer_shape = [3, 32, 100] if config['Architecture'][
+        'model_type'] != "det" else [3, 640, 640]
+    model = to_static(
+        model,
+        input_spec=[
+            paddle.static.InputSpec(
+                shape=[None] + infer_shape, dtype='float32')
+        ])
     paddle.jit.save(model, save_path)
     logger.info('inference model is saved to {}'.format(save_path))
...
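Note: export no longer wraps the network in a fixed-shape Model class; to_static is applied directly with an input shape chosen by model_type ([3, 640, 640] for detection, [3, 32, 100] otherwise), and paddle.jit.save writes to a .../inference prefix. A hedged sketch of reloading such an exported model for a quick check (the output path is a placeholder, assumes Paddle 2.x):

# Sketch: the save_path prefix above should produce inference.pdmodel and
# inference.pdiparams; paddle.jit.load can reload them for a smoke test.
import paddle

save_path = "./output/det/inference"  # placeholder for '{output_path}/{model_type}/inference'
model = paddle.jit.load(save_path)
model.eval()
x = paddle.randn([1, 3, 640, 640], dtype="float32")  # detection input shape from the diff
y = model(x)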
@@ -100,8 +100,8 @@ def create_predictor(args, mode, logger):
     if model_dir is None:
         logger.info("not find {} model file path {}".format(mode, model_dir))
         sys.exit(0)
-    model_file_path = model_dir + "/model"
-    params_file_path = model_dir + "/params"
+    model_file_path = model_dir + ".pdmodel"
+    params_file_path = model_dir + ".pdiparams"
     if not os.path.exists(model_file_path):
         logger.info("not find model file path {}".format(model_file_path))
         sys.exit(0)
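Note: the Python predictor follows the same convention as the C++ code, so the *_model_dir arguments now act as file prefixes rather than directories. A small illustrative helper (not part of the repo) for resolving and checking the two expected files:

# Hypothetical helper: turn a model prefix into the pair of files that
# create_predictor now looks for, failing early if either is missing.
import os

def resolve_inference_files(model_prefix):
    model_file = model_prefix + ".pdmodel"
    params_file = model_prefix + ".pdiparams"
    for path in (model_file, params_file):
        if not os.path.exists(path):
            raise FileNotFoundError("missing inference file: {}".format(path))
    return model_file, params_file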
@@ -230,10 +230,10 @@ def draw_ocr_box_txt(image,
                 box[2][1], box[3][0], box[3][1]
             ],
             outline=color)
-        box_height = math.sqrt((box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][
-            1]) ** 2)
-        box_width = math.sqrt((box[0][0] - box[1][0]) ** 2 + (box[0][1] - box[1][
-            1]) ** 2)
+        box_height = math.sqrt((box[0][0] - box[3][0])**2 + (box[0][1] - box[3][
+            1])**2)
+        box_width = math.sqrt((box[0][0] - box[1][0])**2 + (box[0][1] - box[1][
+            1])**2)
         if box_height > 2 * box_width:
             font_size = max(int(box_width * 0.9), 10)
             font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
...