提交 bb779851 编写于 作者: T tink2123

fix infer_rec and add ic15_dict

上级 e3388a24
...@@ -15,8 +15,9 @@ Global: ...@@ -15,8 +15,9 @@ Global:
character_dict_path: ./ppocr/utils/ppocr_keys_v1.txt character_dict_path: ./ppocr/utils/ppocr_keys_v1.txt
loss_type: ctc loss_type: ctc
reader_yml: ./configs/rec/rec_chinese_reader.yml reader_yml: ./configs/rec/rec_chinese_reader.yml
pretrain_weights: pretrain_weights: best_accuracy
infer_img: ./infer_img
Architecture: Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel function: ppocr.modeling.architectures.rec_model,RecModel
......
TrainReader: TrainReader:
reader_function: ppocr.data.rec.dataset_traversal,SimpleReader reader_function: ppocr.data.rec.dataset_traversal,SimpleReader
num_workers: 8 num_workers: 8
img_set_dir: . img_set_dir: ./train_data
label_file_path: ./train_data/hard_label.txt label_file_path: ./train_data/gt_train.txt
EvalReader: EvalReader:
reader_function: ppocr.data.rec.dataset_traversal,SimpleReader reader_function: ppocr.data.rec.dataset_traversal,SimpleReader
img_set_dir: . img_set_dir: ./train_data
label_file_path: ./train_data/label_val_all.txt label_file_path: ./train_data/gt_test.txt
TestReader: TestReader:
reader_function: ppocr.data.rec.dataset_traversal,SimpleReader reader_function: ppocr.data.rec.dataset_traversal,SimpleReader
infer_img: ./infer_img
...@@ -143,8 +143,9 @@ class SimpleReader(object): ...@@ -143,8 +143,9 @@ class SimpleReader(object):
self.num_workers = 1 self.num_workers = 1
else: else:
self.num_workers = params['num_workers'] self.num_workers = params['num_workers']
self.img_set_dir = params['img_set_dir'] if params['mode'] != 'test':
self.label_file_path = params['label_file_path'] self.img_set_dir = params['img_set_dir']
self.label_file_path = params['label_file_path']
self.char_ops = params['char_ops'] self.char_ops = params['char_ops']
self.image_shape = params['image_shape'] self.image_shape = params['image_shape']
self.loss_type = params['loss_type'] self.loss_type = params['loss_type']
...@@ -164,10 +165,17 @@ class SimpleReader(object): ...@@ -164,10 +165,17 @@ class SimpleReader(object):
def sample_iter_reader(): def sample_iter_reader():
if self.mode == 'test': if self.mode == 'test':
print("infer_img:", self.infer_img) image_file_list = []
img = cv2.imread(self.infer_img) if os.path.isfile(self.infer_img):
norm_img = process_image(img, self.image_shape) image_file_list = [self.infer_img]
yield norm_img elif os.path.isdir(self.infer_img):
for single_file in os.listdir(self.infer_img):
if single_file.endswith('png') or single_file.endswith('jpg'):
image_file_list.append(os.path.join(self.infer_img, single_file))
for single_img in image_file_list:
img = cv2.imread(single_img)
norm_img = process_image(img, self.image_shape)
yield norm_img
with open(self.label_file_path, "rb") as fin: with open(self.label_file_path, "rb") as fin:
label_infor_list = fin.readlines() label_infor_list = fin.readlines()
img_num = len(label_infor_list) img_num = len(label_infor_list)
...@@ -179,6 +187,7 @@ class SimpleReader(object): ...@@ -179,6 +187,7 @@ class SimpleReader(object):
img_path = self.img_set_dir + "/" + substr[0] img_path = self.img_set_dir + "/" + substr[0]
img = cv2.imread(img_path) img = cv2.imread(img_path)
if img is None: if img is None:
logger.info("{} does not exist!".format(img_path))
continue continue
label = substr[1] label = substr[1]
outs = process_image(img, self.image_shape, label, outs = process_image(img, self.image_shape, label,
...@@ -198,4 +207,6 @@ class SimpleReader(object): ...@@ -198,4 +207,6 @@ class SimpleReader(object):
if len(batch_outs) != 0: if len(batch_outs) != 0:
yield batch_outs yield batch_outs
return batch_iter_reader if self.mode != 'test':
return batch_iter_reader
return sample_iter_reader
J
O
I
N
T
y
o
u
r
s
e
l
f
1
5
4
9
7
2
8
0
F
m
P
A
B
L
C
K
S
R
E
Y
U
p
d
g
a
t
i
n
h
W
D
v
H
V
G
w
M
!
k
c
.
(
)
X
b
-
Q
x
Z
?
@
3
/
%
$
,
'
:
z
&
j
6
+
[
]
;
#
q
\
´
É
=
#. /paddle/set_env.sh↩ #. /paddle/set_env.sh↩
export CUDA_VISIBLE_DEVICES="0,1,2,3" export CUDA_VISIBLE_DEVICES="0,1,2,3"
export PYTHONPATH=$PYTHONPATH:.↩ export FLAGS_fraction_of_gpu_memory_to_use=1.0
export FLAGS_fraction_of_gpu_memory_to_use=1.0↩ python_bin_dir="/opt/_internal/cpython-3.7.0/bin/"
alias python=$python_bin_dir"python3.7"
python_bin_dir="/opt/_internal/cpython-3.7.0/bin/" alias pip=$python_bin_dir"pip3.7"
alias python=$python_bin_dir"python3.7" alias ipython=$python_bin_dir"ipython3"
alias pip=$python_bin_dir"pip3.7" export LD_LIBRARY_PATH=/opt/_internal/cpython-3.7.0/lib:$LD_LIBRARY_PATH
alias ipython=$python_bin_dir"ipython3" export PYTHONPATH=$PYTHONPATH:.
export LD_LIBRARY_PATH=/opt/_internal/cpython-3.7.0/lib:$LD_LIBRARY_PATH ldconfig
export PYTHONPATH=$PYTHONPATH:.↩
ldconfig↩
...@@ -20,7 +20,7 @@ import os ...@@ -20,7 +20,7 @@ import os
import time import time
import multiprocessing import multiprocessing
import numpy as np import numpy as np
import glob
def set_paddle_flags(**kwargs): def set_paddle_flags(**kwargs):
for key, value in kwargs.items(): for key, value in kwargs.items():
...@@ -79,9 +79,14 @@ def main(): ...@@ -79,9 +79,14 @@ def main():
init_model(config, eval_prog, exe) init_model(config, eval_prog, exe)
blobs = reader_main(config, 'test') blobs = reader_main(config, 'test')()
imgs = next(blobs()) infer_list = os.listdir(config['Global']['infer_img'])
for img in imgs: max_img_num = len(infer_list)
if len(infer_list) == 0:
logger.info("Can not find img in infer_img dir.")
for i in range(max_img_num):
print("infer_img:",infer_list[i])
img = next(blobs)
predict = exe.run(program=eval_prog, predict = exe.run(program=eval_prog,
feed={"image": img}, feed={"image": img},
fetch_list=fetch_varname_list, fetch_list=fetch_varname_list,
...@@ -101,8 +106,8 @@ def main(): ...@@ -101,8 +106,8 @@ def main():
preds_text = preds_text.reshape(-1) preds_text = preds_text.reshape(-1)
preds_text = char_ops.decode(preds_text) preds_text = char_ops.decode(preds_text)
print(preds) print("\t index:",preds)
print(preds_text) print("\t word :",preds_text)
# save for inference model # save for inference model
target_var = [] target_var = []
......
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册