diff --git a/ppocr/data/simple_dataset.py b/ppocr/data/simple_dataset.py index c8708a4ab94f1761551dc9ecbe17316ac0ab67f7..e9c3394cbe930d5169ae005e7582a2902e697b7e 100644 --- a/ppocr/data/simple_dataset.py +++ b/ppocr/data/simple_dataset.py @@ -14,7 +14,6 @@ import numpy as np import os import random -import traceback from paddle.io import Dataset from .imaug import transform, create_operators @@ -46,7 +45,6 @@ class SimpleDataSet(Dataset): self.seed = seed logger.info("Initialize indexs of datasets:%s" % label_file_list) self.data_lines = self.get_image_info_list(label_file_list, ratio_list) - self.check_data() self.data_idx_order_list = list(range(len(self.data_lines))) if self.mode == "train" and self.do_shuffle: self.shuffle_data_random() @@ -103,18 +101,25 @@ class SimpleDataSet(Dataset): def __getitem__(self, idx): file_idx = self.data_idx_order_list[idx] - data = self.data_lines[file_idx] + data_line = self.data_lines[file_idx] try: + data_line = data_line.decode('utf-8') + substr = data_line.strip("\n").split(self.delimiter) + file_name = substr[0] + label = substr[1] + img_path = os.path.join(self.data_dir, file_name) + data = {'img_path': img_path, 'label': label} + if not os.path.exists(img_path): + raise Exception("{} does not exist!".format(img_path)) with open(data['img_path'], 'rb') as f: img = f.read() data['image'] = img data['ext_data'] = self.get_ext_data() outs = transform(data, self.ops) - except: - error_meg = traceback.format_exc() + except Exception as e: self.logger.error( - "When parsing file {} and label {}, error happened with msg: {}".format( - data['img_path'],data['label'], error_meg)) + "When parsing line {}, error happened with msg: {}".format( + data_line, e)) outs = None if outs is None: # during evaluation, we should fix the idx to get same results for many times of evaluation. @@ -125,17 +130,3 @@ class SimpleDataSet(Dataset): def __len__(self): return len(self.data_idx_order_list) - - def check_data(self): - new_data_lines = [] - for data_line in self.data_lines: - data_line = data_line.decode('utf-8') - substr = data_line.strip("\n").strip("\r").split(self.delimiter) - file_name = substr[0] - label = substr[1] - img_path = os.path.join(self.data_dir, file_name) - if os.path.exists(img_path): - new_data_lines.append({'img_path': img_path, 'label': label}) - else: - self.logger.info("{} does not exist!".format(img_path)) - self.data_lines = new_data_lines \ No newline at end of file diff --git a/ppocr/modeling/architectures/distillation_model.py b/ppocr/modeling/architectures/distillation_model.py index 2b1d3aae3b7303a61b20db15df5ce4bd9bb7b235..1e95fe574433eaca6f322ff47c8547cc1a29a248 100644 --- a/ppocr/modeling/architectures/distillation_model.py +++ b/ppocr/modeling/architectures/distillation_model.py @@ -46,7 +46,7 @@ class DistillationModel(nn.Layer): pretrained = model_config.pop("pretrained") model = BaseModel(model_config) if pretrained is not None: - model = load_pretrained_params(model, pretrained) + load_pretrained_params(model, pretrained) if freeze_params: for param in model.parameters(): param.trainable = False