diff --git a/ppocr/data/det/dataset_traversal.py b/ppocr/data/det/dataset_traversal.py index 92b9aac4f541a754fc218761e28439ea2134a502..c4ad5ced4f0240fa776afce8b7bf3e34310bd186 100644 --- a/ppocr/data/det/dataset_traversal.py +++ b/ppocr/data/det/dataset_traversal.py @@ -36,17 +36,17 @@ class TrainReader(object): "absence process_function in Reader" self.process = create_module(params['process_function'])(params) - def __call__(self, process_id): + def __call__(self, process_id): + with open(self.label_file_path, "rb") as fin: + label_infor_list = fin.readlines() + img_num = len(label_infor_list) + img_id_list = list(range(img_num)) + if sys.platform == "win32" and self.num_workers != 1: + print("multiprocess is not fully compatible with Windows." + "num_workers will be 1.") + self.num_workers = 1 def sample_iter_reader(): - with open(self.label_file_path, "rb") as fin: - label_infor_list = fin.readlines() - img_num = len(label_infor_list) - img_id_list = list(range(img_num)) random.shuffle(img_id_list) - if sys.platform == "win32" and self.num_workers != 1: - print("multiprocess is not fully compatible with Windows." - "num_workers will be 1.") - self.num_workers = 1 for img_id in range(process_id, img_num, self.num_workers): label_infor = label_infor_list[img_id_list[img_id]] outs = self.process(label_infor) diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py index 4d23f62656f5561dd93f40fa97a3f7874e4b2040..74a200f5e5661ecfe1409290871a931bdf18e99d 100755 --- a/ppocr/utils/save_load.py +++ b/ppocr/utils/save_load.py @@ -21,7 +21,6 @@ import os import shutil import tempfile -import paddle import paddle.fluid as fluid from .utility import initial_logger @@ -113,14 +112,12 @@ def init_model(config, program, exe): path = checkpoints fluid.load(program, path, exe) logger.info("Finish initing model from {}".format(path)) - return pretrain_weights = config['Global'].get('pretrain_weights') if pretrain_weights: path = pretrain_weights load_params(exe, program, path) logger.info("Finish initing model from {}".format(path)) - return def save_model(program, model_path):