Unverified commit 055f207f, authored by MissPenguin and committed by GitHub

Merge pull request #1371 from tink2123/update_multi

update multi dic and export
 Global:
-  use_gpu: true
+  use_gpu: True
   epoch_num: 500
   log_smooth_window: 20
   print_batch_step: 10
@@ -15,7 +15,7 @@ Global:
   use_visualdl: False
   infer_img:
   # for data or label process
-  character_dict_path: ppocr/utils/dict/ic15_dict.txt
+  character_dict_path: ppocr/utils/dict/en_dict.txt
   character_type: ch
   max_text_length: 25
   infer_mode: False
...
 Global:
-  use_gpu: true
+  use_gpu: True
   epoch_num: 500
   log_smooth_window: 20
   print_batch_step: 10
@@ -19,7 +19,7 @@ Global:
   character_type: french
   max_text_length: 25
   infer_mode: False
-  use_space_char: True
+  use_space_char: False

 Optimizer:
...
 Global:
-  use_gpu: true
+  use_gpu: True
   epoch_num: 500
   log_smooth_window: 20
   print_batch_step: 10
@@ -19,7 +19,7 @@ Global:
   character_type: german
   max_text_length: 25
   infer_mode: False
-  use_space_char: True
+  use_space_char: False

 Optimizer:
...
 Global:
-  use_gpu: true
+  use_gpu: True
   epoch_num: 500
   log_smooth_window: 20
   print_batch_step: 10
@@ -19,7 +19,7 @@ Global:
   character_type: japan
   max_text_length: 25
   infer_mode: False
-  use_space_char: True
+  use_space_char: False

 Optimizer:
...
 Global:
-  use_gpu: true
+  use_gpu: True
   epoch_num: 500
   log_smooth_window: 20
   print_batch_step: 10
@@ -19,7 +19,7 @@ Global:
   character_type: korean
   max_text_length: 25
   infer_mode: False
-  use_space_char: True
+  use_space_char: False

 Optimizer:
...
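The configuration hunks above make two kinds of changes: the English recognition config now points character_dict_path at ppocr/utils/dict/en_dict.txt instead of ic15_dict.txt, and the French, German, Japanese and Korean configs switch use_space_char from True to False. As a rough guide to what these two keys control, here is a simplified sketch of how a character dictionary and the use_space_char flag are typically turned into the recognition character set; the function is illustrative only and not the exact code in ppocr's label encode/decode classes.

def build_char_list(character_dict_path, use_space_char=False):
    # One character per line in the dictionary file; order matters, since the
    # line index becomes the class index for that character.
    chars = []
    with open(character_dict_path, "rb") as f:
        for line in f:
            chars.append(line.decode("utf-8").strip("\n").strip("\r\n"))
    if use_space_char:
        chars.append(" ")  # the space class is only added when the config enables it
    return chars

# With the 62-entry en_dict.txt listed further below this yields 62 classes
# (a CTC blank is added separately); with use_space_char: False no extra
# space class is appended.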
@@ -81,7 +81,7 @@ cv::Mat Classifier::Run(cv::Mat &img) {
 void Classifier::LoadModel(const std::string &model_dir) {
   AnalysisConfig config;
-  config.SetModel(model_dir + "/model", model_dir + "/params");
+  config.SetModel(model_dir + ".pdmodel", model_dir + ".pdiparams");

   if (this->use_gpu_) {
     config.EnableUseGpu(this->gpu_mem_, this->gpu_id_);
...
@@ -18,7 +18,7 @@ namespace PaddleOCR {
 void DBDetector::LoadModel(const std::string &model_dir) {
   AnalysisConfig config;
-  config.SetModel(model_dir + "/model", model_dir + "/params");
+  config.SetModel(model_dir + ".pdmodel", model_dir + ".pdiparams");

   if (this->use_gpu_) {
     config.EnableUseGpu(this->gpu_mem_, this->gpu_id_);
...
@@ -103,7 +103,7 @@ void CRNNRecognizer::Run(std::vector<std::vector<std::vector<int>>> boxes,
 void CRNNRecognizer::LoadModel(const std::string &model_dir) {
   AnalysisConfig config;
-  config.SetModel(model_dir + "/model", model_dir + "/params");
+  config.SetModel(model_dir + ".pdmodel", model_dir + ".pdiparams");

   if (this->use_gpu_) {
     config.EnableUseGpu(this->gpu_mem_, this->gpu_id_);
...
0
1
2
3
4
5
6
7
8
9
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
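The 62 lines above (digits, then lower-case and upper-case Latin letters) match the new English dictionary that the config change points to, ppocr/utils/dict/en_dict.txt. When editing a dictionary like this it is easy to drop a character that still appears in the labels, so a quick coverage check is worth running before training; the helper below is hypothetical and assumes PaddleOCR's usual recognition label format of one "image_path\tlabel" pair per line.

def load_charset(dict_path):
    # Read a one-character-per-line dictionary into a set.
    with open(dict_path, "rb") as f:
        return {line.decode("utf-8").strip("\n").strip("\r\n") for line in f}

def chars_missing_from_dict(label_path, charset, ignore_space=True):
    # Collect every label character that the dictionary does not cover.
    missing = set()
    with open(label_path, "r", encoding="utf-8") as f:
        for line in f:
            _, label = line.rstrip("\n").split("\t", 1)
            for ch in label:
                if ignore_space and ch == " ":
                    continue
                if ch not in charset:
                    missing.add(ch)
    return missing

# Example (paths are illustrative):
# charset = load_charset("ppocr/utils/dict/en_dict.txt")
# print(chars_missing_from_dict("train_data/rec/train_list.txt", charset))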
@@ -133,3 +133,4 @@ j
 Å
 $
 #
@@ -124,3 +124,4 @@ z
 å
 æ
 é
@@ -4396,3 +4396,4 @@ z
@@ -179,7 +179,7 @@ z
 с
 т
 я
@@ -3685,3 +3685,4 @@ z
@@ -39,26 +39,12 @@ def parse_args():
     return parser.parse_args()

-class Model(paddle.nn.Layer):
-    def __init__(self, model):
-        super(Model, self).__init__()
-        self.pre_model = model
-
-    # Please modify the 'shape' according to actual needs
-    @to_static(input_spec=[
-        paddle.static.InputSpec(
-            shape=[None, 3, 640, 640], dtype='float32')
-    ])
-    def forward(self, inputs):
-        x = self.pre_model(inputs)
-        return x

 def main():
     FLAGS = parse_args()
     config = load_config(FLAGS.config)
     logger = get_logger()
     # build post process
     post_process_class = build_post_process(config['PostProcess'],
                                             config['Global'])
@@ -71,9 +57,16 @@ def main():
     init_model(config, model, logger)
     model.eval()
-    model = Model(model)
-    save_path = '{}/{}'.format(FLAGS.output_path,
-                               config['Architecture']['model_type'])
+    save_path = '{}/{}/inference'.format(FLAGS.output_path,
+                                         config['Architecture']['model_type'])
+    infer_shape = [3, 32, 100] if config['Architecture'][
+        'model_type'] != "det" else [3, 640, 640]
+    model = to_static(
+        model,
+        input_spec=[
+            paddle.static.InputSpec(
+                shape=[None] + infer_shape, dtype='float32')
+        ])
     paddle.jit.save(model, save_path)
     logger.info('inference model is saved to {}'.format(save_path))
...
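The export change above removes the wrapper Model class with its hard-coded 640x640 InputSpec: the export script (tools/export_model.py) now applies paddle.jit.to_static to the trained model directly, picks the input shape from the model type, and saves to an "inference" prefix inside the output directory. The snippet below is a minimal, self-contained sketch of that pattern with a toy layer standing in for the real PaddleOCR model; the layer, shapes and paths are illustrative only, and it assumes paddlepaddle >= 2.0.

# Minimal sketch of the new export pattern (toy layer instead of a real
# PaddleOCR model; shapes follow the diff: [3, 32, 100] for rec/cls,
# [3, 640, 640] for det).
import paddle
from paddle.jit import to_static


class TinyRecModel(paddle.nn.Layer):
    def __init__(self):
        super(TinyRecModel, self).__init__()
        self.conv = paddle.nn.Conv2D(3, 8, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)


model = TinyRecModel()
model.eval()
infer_shape = [3, 32, 100]  # rec-style input; a det model would use [3, 640, 640]
model = to_static(
    model,
    input_spec=[
        paddle.static.InputSpec(
            shape=[None] + infer_shape, dtype='float32')
    ])
save_path = "./output/rec/inference"  # illustrative output prefix
paddle.jit.save(model, save_path)
# paddle.jit.save writes <save_path>.pdmodel and <save_path>.pdiparams,
# which is the naming the predictors below now expect.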
@@ -100,8 +100,8 @@ def create_predictor(args, mode, logger):
     if model_dir is None:
         logger.info("not find {} model file path {}".format(mode, model_dir))
         sys.exit(0)
-    model_file_path = model_dir + "/model"
-    params_file_path = model_dir + "/params"
+    model_file_path = model_dir + ".pdmodel"
+    params_file_path = model_dir + ".pdiparams"
     if not os.path.exists(model_file_path):
         logger.info("not find model file path {}".format(model_file_path))
         sys.exit(0)
...
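Taken together with the C++ LoadModel changes above, the predictor-side convention is now that model_dir names a file prefix rather than a directory containing separate model and params files: the code looks for <model_dir>.pdmodel and <model_dir>.pdiparams. The sketch below shows one way to load such an exported model through the paddle.inference Python API; the prefix is illustrative, it assumes paddlepaddle >= 2.0, and it is not necessarily the exact API used inside the predictor utility.

# Sketch: point the Paddle inference API at a model exported under the new
# naming scheme. The prefix matches the export sketch earlier, so
# ./output/rec/inference.pdmodel and ./output/rec/inference.pdiparams
# live next to each other.
from paddle import inference

model_prefix = "./output/rec/inference"  # illustrative prefix
config = inference.Config(model_prefix + ".pdmodel",
                          model_prefix + ".pdiparams")
config.disable_gpu()  # keep the sketch CPU-only
predictor = inference.create_predictor(config)
print(predictor.get_input_names())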