未验证 提交 aa9bedf3 编写于 作者: D dyning 提交者: GitHub

Merge pull request #4 from tink2123/develop

polish infer_rec and add ic15_dict
TrainReader: TrainReader:
reader_function: ppocr.data.rec.dataset_traversal,SimpleReader reader_function: ppocr.data.rec.dataset_traversal,SimpleReader
num_workers: 8 num_workers: 8
img_set_dir: . img_set_dir: ./train_data
label_file_path: ./train_data/hard_label.txt label_file_path: ./train_data/rec_gt_train.txt
EvalReader: EvalReader:
reader_function: ppocr.data.rec.dataset_traversal,SimpleReader reader_function: ppocr.data.rec.dataset_traversal,SimpleReader
img_set_dir: . img_set_dir: ./train_data
label_file_path: ./train_data/label_val_all.txt label_file_path: ./train_data/rec_gt_test.txt
TestReader: TestReader:
reader_function: ppocr.data.rec.dataset_traversal,SimpleReader reader_function: ppocr.data.rec.dataset_traversal,SimpleReader
......
...@@ -4,7 +4,7 @@ Global: ...@@ -4,7 +4,7 @@ Global:
epoch_num: 300 epoch_num: 300
log_smooth_window: 20 log_smooth_window: 20
print_batch_step: 10 print_batch_step: 10
save_model_dir: output save_model_dir: output_ic15
save_epoch_step: 3 save_epoch_step: 3
eval_batch_step: 2000 eval_batch_step: 2000
train_batch_size_per_card: 256 train_batch_size_per_card: 256
...@@ -12,11 +12,12 @@ Global: ...@@ -12,11 +12,12 @@ Global:
image_shape: [3, 32, 100] image_shape: [3, 32, 100]
max_text_length: 25 max_text_length: 25
character_type: ch character_type: ch
character_dict_path: ./ppocr/utils/ppocr_keys_v1.txt character_dict_path: ./ppocr/utils/ic15_dict.txt
loss_type: ctc loss_type: ctc
reader_yml: ./configs/rec/rec_chinese_reader.yml reader_yml: ./configs/rec/rec_icdar15_reader.yml
pretrain_weights: pretrain_weights: ./pretrain_models/CRNN/best_accuracy
checkpoints:
save_inference_dir:
Architecture: Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel function: ppocr.modeling.architectures.rec_model,RecModel
......
...@@ -22,6 +22,7 @@ import string ...@@ -22,6 +22,7 @@ import string
import lmdb import lmdb
from ppocr.utils.utility import initial_logger from ppocr.utils.utility import initial_logger
from ppocr.utils.utility import get_image_file_list
logger = initial_logger() logger = initial_logger()
from .img_tools import process_image, get_img_data from .img_tools import process_image, get_img_data
...@@ -143,8 +144,9 @@ class SimpleReader(object): ...@@ -143,8 +144,9 @@ class SimpleReader(object):
self.num_workers = 1 self.num_workers = 1
else: else:
self.num_workers = params['num_workers'] self.num_workers = params['num_workers']
self.img_set_dir = params['img_set_dir'] if params['mode'] != 'test':
self.label_file_path = params['label_file_path'] self.img_set_dir = params['img_set_dir']
self.label_file_path = params['label_file_path']
self.char_ops = params['char_ops'] self.char_ops = params['char_ops']
self.image_shape = params['image_shape'] self.image_shape = params['image_shape']
self.loss_type = params['loss_type'] self.loss_type = params['loss_type']
...@@ -164,29 +166,34 @@ class SimpleReader(object): ...@@ -164,29 +166,34 @@ class SimpleReader(object):
def sample_iter_reader(): def sample_iter_reader():
if self.mode == 'test': if self.mode == 'test':
print("infer_img:", self.infer_img) image_file_list = get_image_file_list(self.infer_img)
img = cv2.imread(self.infer_img) for single_img in image_file_list:
norm_img = process_image(img, self.image_shape) img = cv2.imread(single_img)
yield norm_img if img.shape[-1]==1 or len(list(img.shape))==2:
with open(self.label_file_path, "rb") as fin: img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
label_infor_list = fin.readlines() norm_img = process_image(img, self.image_shape)
img_num = len(label_infor_list) yield norm_img
img_id_list = list(range(img_num)) else:
random.shuffle(img_id_list) with open(self.label_file_path, "rb") as fin:
for img_id in range(process_id, img_num, self.num_workers): label_infor_list = fin.readlines()
label_infor = label_infor_list[img_id_list[img_id]] img_num = len(label_infor_list)
substr = label_infor.decode('utf-8').strip("\n").split("\t") img_id_list = list(range(img_num))
img_path = self.img_set_dir + "/" + substr[0] random.shuffle(img_id_list)
img = cv2.imread(img_path) for img_id in range(process_id, img_num, self.num_workers):
if img is None: label_infor = label_infor_list[img_id_list[img_id]]
continue substr = label_infor.decode('utf-8').strip("\n").split("\t")
label = substr[1] img_path = self.img_set_dir + "/" + substr[0]
outs = process_image(img, self.image_shape, label, img = cv2.imread(img_path)
self.char_ops, self.loss_type, if img is None:
self.max_text_length) logger.info("{} does not exist!".format(img_path))
if outs is None: continue
continue label = substr[1]
yield outs outs = process_image(img, self.image_shape, label,
self.char_ops, self.loss_type,
self.max_text_length)
if outs is None:
continue
yield outs
def batch_iter_reader(): def batch_iter_reader():
batch_outs = [] batch_outs = []
...@@ -198,4 +205,6 @@ class SimpleReader(object): ...@@ -198,4 +205,6 @@ class SimpleReader(object):
if len(batch_outs) != 0: if len(batch_outs) != 0:
yield batch_outs yield batch_outs
return batch_iter_reader if self.mode != 'test':
return batch_iter_reader
return sample_iter_reader
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
0
1
2
3
4
5
6
7
8
9
#. /paddle/set_env.sh↩
export CUDA_VISIBLE_DEVICES="0,1,2,3"
export PYTHONPATH=$PYTHONPATH:.↩
export FLAGS_fraction_of_gpu_memory_to_use=1.0↩
python_bin_dir="/opt/_internal/cpython-3.7.0/bin/"
alias python=$python_bin_dir"python3.7"
alias pip=$python_bin_dir"pip3.7"
alias ipython=$python_bin_dir"ipython3"
export LD_LIBRARY_PATH=/opt/_internal/cpython-3.7.0/lib:$LD_LIBRARY_PATH
export PYTHONPATH=$PYTHONPATH:.↩
ldconfig↩
...@@ -80,7 +80,7 @@ def main(): ...@@ -80,7 +80,7 @@ def main():
metrics = eval_det_run(exe, config, eval_info_dict, "test") metrics = eval_det_run(exe, config, eval_info_dict, "test")
else: else:
reader_type = config['Global']['reader_yml'] reader_type = config['Global']['reader_yml']
if "chinese" in reader_type: if "benchmark" not in reader_type:
eval_reader = reader_main(config=config, mode="eval") eval_reader = reader_main(config=config, mode="eval")
eval_info_dict = {'program': eval_program, \ eval_info_dict = {'program': eval_program, \
'reader': eval_reader, \ 'reader': eval_reader, \
......
...@@ -21,7 +21,6 @@ import time ...@@ -21,7 +21,6 @@ import time
import multiprocessing import multiprocessing
import numpy as np import numpy as np
def set_paddle_flags(**kwargs): def set_paddle_flags(**kwargs):
for key, value in kwargs.items(): for key, value in kwargs.items():
if os.environ.get(key, None) is None: if os.environ.get(key, None) is None:
...@@ -47,7 +46,7 @@ from ppocr.data.reader_main import reader_main ...@@ -47,7 +46,7 @@ from ppocr.data.reader_main import reader_main
from ppocr.utils.save_load import init_model from ppocr.utils.save_load import init_model
from ppocr.utils.character import CharacterOps from ppocr.utils.character import CharacterOps
from ppocr.utils.utility import create_module from ppocr.utils.utility import create_module
from ppocr.utils.utility import get_image_file_list
logger = initial_logger() logger = initial_logger()
...@@ -79,9 +78,15 @@ def main(): ...@@ -79,9 +78,15 @@ def main():
init_model(config, eval_prog, exe) init_model(config, eval_prog, exe)
blobs = reader_main(config, 'test') blobs = reader_main(config, 'test')()
imgs = next(blobs()) infer_img = config['TestReader']['infer_img']
for img in imgs: infer_list = get_image_file_list(infer_img)
max_img_num = len(infer_list)
if len(infer_list) == 0:
logger.info("Can not find img in infer_img dir.")
for i in range(max_img_num):
print("infer_img:",infer_list[i])
img = next(blobs)
predict = exe.run(program=eval_prog, predict = exe.run(program=eval_prog,
feed={"image": img}, feed={"image": img},
fetch_list=fetch_varname_list, fetch_list=fetch_varname_list,
...@@ -101,8 +106,8 @@ def main(): ...@@ -101,8 +106,8 @@ def main():
preds_text = preds_text.reshape(-1) preds_text = preds_text.reshape(-1)
preds_text = char_ops.decode(preds_text) preds_text = char_ops.decode(preds_text)
print(preds) print("\t index:",preds)
print(preds_text) print("\t word :",preds_text)
# save for inference model # save for inference model
target_var = [] target_var = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册