From 3ecefa486a5c4055d80089f998e6030fa7644b45 Mon Sep 17 00:00:00 2001 From: WuHaobo Date: Mon, 18 May 2020 09:51:17 +0800 Subject: [PATCH] refine reader to support paddle1.8 API and fix the code style --- ppcls/data/reader.py | 45 ++++++++++++++++++++++---------------------- 1 file changed, 22 insertions(+), 23 deletions(-) diff --git a/ppcls/data/reader.py b/ppcls/data/reader.py index 5bf83c21..92199630 100755 --- a/ppcls/data/reader.py +++ b/ppcls/data/reader.py @@ -1,27 +1,26 @@ -#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # -#Licensed under the Apache License, Version 2.0 (the "License"); -#you may not use this file except in compliance with the License. -#You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # -#Unless required by applicable law or agreed to in writing, software -#distributed under the License is distributed on an "AS IS" BASIS, -#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -#See the License for the specific language governing permissions and -#limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -import cv2 import numpy as np +import imghdr import os import signal -import paddle +from paddle.fluid.io import multiprocess_reader from . import imaug from .imaug import transform -from .imaug import MixupOperator from ppcls.utils import logger trainers_num = int(os.environ.get('PADDLE_TRAINERS_NUM', 1)) @@ -35,7 +34,7 @@ class ModeException(Exception): def __init__(self, message='', mode=''): message += "\nOnly the following 3 modes are supported: " \ - "train, valid, test. Given mode is {}".format(mode) + "train, valid, test. Given mode is {}".format(mode) super(ModeException, self).__init__(message) @@ -46,10 +45,10 @@ class SampleNumException(Exception): def __init__(self, message='', sample_num=0, batch_size=1): message += "\nError: The number of the whole data ({}) " \ - "is smaller than the batch_size ({}), and drop_last " \ - "is turnning on, so nothing will feed in program, " \ - "Terminated now. Please reset batch_size to a smaller " \ - "number or feed more data!".format(sample_num, batch_size) + "is smaller than the batch_size ({}), and drop_last " \ + "is turnning on, so nothing will feed in program, " \ + "Terminated now. Please reset batch_size to a smaller " \ + "number or feed more data!".format(sample_num, batch_size) super(SampleNumException, self).__init__(message) @@ -80,12 +79,12 @@ def check_params(params): data_dir = params.get('data_dir', '') assert os.path.isdir(data_dir), \ - "{} doesn't exist, please check datadir path".format(data_dir) + "{} doesn't exist, please check datadir path".format(data_dir) if params['mode'] != 'test': file_list = params.get('file_list', '') assert os.path.isfile(file_list), \ - "{} doesn't exist, please check file list path".format(file_list) + "{} doesn't exist, please check file list path".format(file_list) def create_file_list(params): @@ -176,8 +175,8 @@ def partial_reader(params, full_lines, part_id=0, part_num=1): part_id(int): part index of the current partial data part_num(int): part num of the dataset """ - assert part_id < part_num, ("part_num: {} should be larger " \ - "than part_id: {}".format(part_num, part_id)) + assert part_id < part_num, ("part_num: {} should be larger " + "than part_id: {}".format(part_num, part_id)) full_lines = full_lines[part_id::part_num] @@ -216,11 +215,11 @@ def mp_reader(params): for part_id in range(part_num): readers.append(partial_reader(params, full_lines, part_id, part_num)) - return paddle.reader.multiprocess_reader(readers, use_pipe=False) + return multiprocess_reader(readers, use_pipe=False) def term_mp(sig_num, frame): - """ kill all child processes + """ kill all child processes """ pid = os.getpid() pgid = os.getpgid(os.getpid()) -- GitLab