From 60b34a5192d367db5477adcf293f6785390ae62d Mon Sep 17 00:00:00 2001 From: littletomatodonkey Date: Sun, 10 May 2020 09:00:43 +0000 Subject: [PATCH] add delimiter reader --- ppcls/data/reader.py | 45 ++++++++++++++++++++++---------------------- 1 file changed, 22 insertions(+), 23 deletions(-) diff --git a/ppcls/data/reader.py b/ppcls/data/reader.py index 5bf83c21..20072db9 100755 --- a/ppcls/data/reader.py +++ b/ppcls/data/reader.py @@ -1,27 +1,25 @@ -#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # -#Licensed under the Apache License, Version 2.0 (the "License"); -#you may not use this file except in compliance with the License. -#You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # -#Unless required by applicable law or agreed to in writing, software -#distributed under the License is distributed on an "AS IS" BASIS, -#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -#See the License for the specific language governing permissions and -#limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -import cv2 import numpy as np import os import signal - +import imghdr import paddle from . import imaug from .imaug import transform -from .imaug import MixupOperator from ppcls.utils import logger trainers_num = int(os.environ.get('PADDLE_TRAINERS_NUM', 1)) @@ -35,7 +33,7 @@ class ModeException(Exception): def __init__(self, message='', mode=''): message += "\nOnly the following 3 modes are supported: " \ - "train, valid, test. Given mode is {}".format(mode) + "train, valid, test. Given mode is {}".format(mode) super(ModeException, self).__init__(message) @@ -46,10 +44,10 @@ class SampleNumException(Exception): def __init__(self, message='', sample_num=0, batch_size=1): message += "\nError: The number of the whole data ({}) " \ - "is smaller than the batch_size ({}), and drop_last " \ - "is turnning on, so nothing will feed in program, " \ - "Terminated now. Please reset batch_size to a smaller " \ - "number or feed more data!".format(sample_num, batch_size) + "is smaller than the batch_size ({}), and drop_last " \ + "is turnning on, so nothing will feed in program, " \ + "Terminated now. Please reset batch_size to a smaller " \ + "number or feed more data!".format(sample_num, batch_size) super(SampleNumException, self).__init__(message) @@ -80,12 +78,12 @@ def check_params(params): data_dir = params.get('data_dir', '') assert os.path.isdir(data_dir), \ - "{} doesn't exist, please check datadir path".format(data_dir) + "{} doesn't exist, please check datadir path".format(data_dir) if params['mode'] != 'test': file_list = params.get('file_list', '') assert os.path.isfile(file_list), \ - "{} doesn't exist, please check file list path".format(file_list) + "{} doesn't exist, please check file list path".format(file_list) def create_file_list(params): @@ -176,8 +174,8 @@ def partial_reader(params, full_lines, part_id=0, part_num=1): part_id(int): part index of the current partial data part_num(int): part num of the dataset """ - assert part_id < part_num, ("part_num: {} should be larger " \ - "than part_id: {}".format(part_num, part_id)) + assert part_id < part_num, ("part_num: {} should be larger " + "than part_id: {}".format(part_num, part_id)) full_lines = full_lines[part_id::part_num] @@ -187,8 +185,9 @@ def partial_reader(params, full_lines, part_id=0, part_num=1): def reader(): ops = create_operators(params['transforms']) + delimiter = params.get('delimiter', ' ') for line in full_lines: - img_path, label = line.split() + img_path, label = line.split(delimiter) img_path = os.path.join(params['data_dir'], img_path) with open(img_path, 'rb') as f: img = f.read() @@ -220,7 +219,7 @@ def mp_reader(params): def term_mp(sig_num, frame): - """ kill all child processes + """ kill all child processes """ pid = os.getpid() pgid = os.getpgid(os.getpid()) -- GitLab