提交 bb779851 编写于 作者: T tink2123

fix infer_rec and add ic15_dict

上级 e3388a24
......@@ -15,8 +15,9 @@ Global:
character_dict_path: ./ppocr/utils/ppocr_keys_v1.txt
loss_type: ctc
reader_yml: ./configs/rec/rec_chinese_reader.yml
pretrain_weights:
pretrain_weights: best_accuracy
infer_img: ./infer_img
Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel
......
TrainReader:
reader_function: ppocr.data.rec.dataset_traversal,SimpleReader
num_workers: 8
img_set_dir: .
label_file_path: ./train_data/hard_label.txt
img_set_dir: ./train_data
label_file_path: ./train_data/gt_train.txt
EvalReader:
reader_function: ppocr.data.rec.dataset_traversal,SimpleReader
img_set_dir: .
label_file_path: ./train_data/label_val_all.txt
img_set_dir: ./train_data
label_file_path: ./train_data/gt_test.txt
TestReader:
reader_function: ppocr.data.rec.dataset_traversal,SimpleReader
infer_img: ./infer_img
......@@ -143,8 +143,9 @@ class SimpleReader(object):
self.num_workers = 1
else:
self.num_workers = params['num_workers']
self.img_set_dir = params['img_set_dir']
self.label_file_path = params['label_file_path']
if params['mode'] != 'test':
self.img_set_dir = params['img_set_dir']
self.label_file_path = params['label_file_path']
self.char_ops = params['char_ops']
self.image_shape = params['image_shape']
self.loss_type = params['loss_type']
......@@ -164,10 +165,17 @@ class SimpleReader(object):
def sample_iter_reader():
if self.mode == 'test':
print("infer_img:", self.infer_img)
img = cv2.imread(self.infer_img)
norm_img = process_image(img, self.image_shape)
yield norm_img
image_file_list = []
if os.path.isfile(self.infer_img):
image_file_list = [self.infer_img]
elif os.path.isdir(self.infer_img):
for single_file in os.listdir(self.infer_img):
if single_file.endswith('png') or single_file.endswith('jpg'):
image_file_list.append(os.path.join(self.infer_img, single_file))
for single_img in image_file_list:
img = cv2.imread(single_img)
norm_img = process_image(img, self.image_shape)
yield norm_img
with open(self.label_file_path, "rb") as fin:
label_infor_list = fin.readlines()
img_num = len(label_infor_list)
......@@ -179,6 +187,7 @@ class SimpleReader(object):
img_path = self.img_set_dir + "/" + substr[0]
img = cv2.imread(img_path)
if img is None:
logger.info("{} does not exist!".format(img_path))
continue
label = substr[1]
outs = process_image(img, self.image_shape, label,
......@@ -198,4 +207,6 @@ class SimpleReader(object):
if len(batch_outs) != 0:
yield batch_outs
return batch_iter_reader
if self.mode != 'test':
return batch_iter_reader
return sample_iter_reader
J
O
I
N
T
y
o
u
r
s
e
l
f
1
5
4
9
7
2
8
0
F
m
P
A
B
L
C
K
S
R
E
Y
U
p
d
g
a
t
i
n
h
W
D
v
H
V
G
w
M
!
k
c
.
(
)
X
b
-
Q
x
Z
?
@
3
/
%
$
,
'
:
z
&
j
6
+
[
]
;
#
q
\
´
É
=
#. /paddle/set_env.sh↩
export CUDA_VISIBLE_DEVICES="0,1,2,3"
export PYTHONPATH=$PYTHONPATH:.↩
export FLAGS_fraction_of_gpu_memory_to_use=1.0↩
python_bin_dir="/opt/_internal/cpython-3.7.0/bin/"
alias python=$python_bin_dir"python3.7"
alias pip=$python_bin_dir"pip3.7"
alias ipython=$python_bin_dir"ipython3"
export LD_LIBRARY_PATH=/opt/_internal/cpython-3.7.0/lib:$LD_LIBRARY_PATH
export PYTHONPATH=$PYTHONPATH:.↩
ldconfig↩
export CUDA_VISIBLE_DEVICES="0,1,2,3"
export FLAGS_fraction_of_gpu_memory_to_use=1.0
python_bin_dir="/opt/_internal/cpython-3.7.0/bin/"
alias python=$python_bin_dir"python3.7"
alias pip=$python_bin_dir"pip3.7"
alias ipython=$python_bin_dir"ipython3"
export LD_LIBRARY_PATH=/opt/_internal/cpython-3.7.0/lib:$LD_LIBRARY_PATH
export PYTHONPATH=$PYTHONPATH:.
ldconfig
......@@ -20,7 +20,7 @@ import os
import time
import multiprocessing
import numpy as np
import glob
def set_paddle_flags(**kwargs):
for key, value in kwargs.items():
......@@ -79,9 +79,14 @@ def main():
init_model(config, eval_prog, exe)
blobs = reader_main(config, 'test')
imgs = next(blobs())
for img in imgs:
blobs = reader_main(config, 'test')()
infer_list = os.listdir(config['Global']['infer_img'])
max_img_num = len(infer_list)
if len(infer_list) == 0:
logger.info("Can not find img in infer_img dir.")
for i in range(max_img_num):
print("infer_img:",infer_list[i])
img = next(blobs)
predict = exe.run(program=eval_prog,
feed={"image": img},
fetch_list=fetch_varname_list,
......@@ -101,8 +106,8 @@ def main():
preds_text = preds_text.reshape(-1)
preds_text = char_ops.decode(preds_text)
print(preds)
print(preds_text)
print("\t index:",preds)
print("\t word :",preds_text)
# save for inference model
target_var = []
......
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册