diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 15d8e516abaffd444f20057c1099658c5aa215d5..1584bc76a9dd8ddff9d05a8cb693bcbd2e09fcde 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,35 +1,35 @@
-- repo: https://github.com/PaddlePaddle/mirrors-yapf.git
-  sha: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37
-  hooks:
-  - id: yapf
-    files: \.py$
-
-- repo: https://github.com/pre-commit/mirrors-autopep8
-  rev: v1.5
-  hooks:
-  - id: autopep8
-
-- repo: https://github.com/Lucas-C/pre-commit-hooks
-  sha: v1.0.1
-  hooks:
-  - id: forbid-crlf
-    files: \.(md|yml)$
-  - id: remove-crlf
-    files: \.(md|yml)$
-  - id: forbid-tabs
-    files: \.(md|yml)$
-  - id: remove-tabs
-    files: \.(md|yml)$
-
-- repo: https://github.com/pre-commit/pre-commit-hooks
-  rev: v2.5.0
-  hooks:
-  - id: check-yaml
-  - id: check-merge-conflict
-  - id: detect-private-key
-    files: (?!.*paddle)^.*$
-  - id: end-of-file-fixer
-    files: \.(md|yml)$
-  - id: trailing-whitespace
-    files: \.(md|yml)$
-  - id: check-case-conflict
+- repo: https://github.com/PaddlePaddle/mirrors-yapf.git
+  sha: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37
+  hooks:
+  - id: yapf
+    files: \.py$
+- repo: https://github.com/pre-commit/pre-commit-hooks
+  sha: a11d9314b22d8f8c7556443875b731ef05965464
+  hooks:
+  - id: check-merge-conflict
+  - id: check-symlinks
+  - id: detect-private-key
+    files: (?!.*paddle)^.*$
+  - id: end-of-file-fixer
+    files: \.md$
+  - id: trailing-whitespace
+    files: \.md$
+- repo: https://github.com/Lucas-C/pre-commit-hooks
+  sha: v1.0.1
+  hooks:
+  - id: forbid-crlf
+    files: \.md$
+  - id: remove-crlf
+    files: \.md$
+  - id: forbid-tabs
+    files: \.md$
+  - id: remove-tabs
+    files: \.md$
+- repo: local
+  hooks:
+  - id: clang-format
+    name: clang-format
+    description: Format files with ClangFormat
+    entry: bash .clang_format.hook -i
+    language: system
+    files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$
diff --git a/configs/MobileNetV3/MobileNetV3_large_x1_0.yaml b/configs/MobileNetV3/MobileNetV3_large_x1_0.yaml
index e924d4ea9789bb7bbab91723d0166cfca2308358..6f9e5597f42abbe1398cb9c8a6e16ec246f26c88 100644
--- a/configs/MobileNetV3/MobileNetV3_large_x1_0.yaml
+++ b/configs/MobileNetV3/MobileNetV3_large_x1_0.yaml
@@ -54,7 +54,7 @@ TRAIN:
 
 VALID:
-    batch_size: 32
+    batch_size: 1024
     num_workers: 4
     file_list: "./dataset/ILSVRC2012/val_list.txt"
    data_dir: "./dataset/ILSVRC2012/"
 
diff --git a/ppcls/data/reader.py b/ppcls/data/reader.py
index cd0bf1350129f196324bf0b5c010c07e3239c45c..17035a031ac18ab65828b529c84df439794fe7cd 100755
--- a/ppcls/data/reader.py
+++ b/ppcls/data/reader.py
@@ -17,7 +17,7 @@ import imghdr
 import os
 import signal
 
-from paddle.reader import multiprocess_reader
+from paddle.io import Dataset, DataLoader, DistributedBatchSampler
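+# paddle.io replaces the old multiprocess_reader pipeline: a map-style
+# Dataset feeds a DataLoader, and DistributedBatchSampler takes over the
+# per-trainer sharding that get_file_list() below used to do by hand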
 
 from . import imaug
 from .imaug import transform
@@ -109,7 +109,6 @@ def create_file_list(params):
 def shuffle_lines(full_lines, seed=None):
     """
    random shuffle lines
-
    Args:
        full_lines(list):
        seed(int): random seed
@@ -135,12 +134,8 @@ def get_file_list(params):
     with open(params['file_list']) as flist:
         full_lines = [line.strip() for line in flist]
 
-    full_lines = shuffle_lines(full_lines, params["shuffle_seed"])
-
-    # use only partial data for each trainer in distributed training
-    if params['mode'] == 'train':
-        img_per_trainer = len(full_lines) // trainers_num
-        full_lines = full_lines[trainer_id::trainers_num][:img_per_trainer]
+    if params["mode"] == "train":
+        full_lines = shuffle_lines(full_lines, seed=params['shuffle_seed'])
 
     return full_lines
 
@@ -165,60 +160,6 @@ def create_operators(params):
     return ops
 
 
-def partial_reader(params, full_lines, part_id=0, part_num=1):
-    """
-    create a reader with partial data
-
-    Args:
-        params(dict):
-        full_lines: label list
-        part_id(int): part index of the current partial data
-        part_num(int): part num of the dataset
-    """
-    assert part_id < part_num, ("part_num: {} should be larger "
-                                "than part_id: {}".format(part_num, part_id))
-
-    full_lines = full_lines[part_id::part_num]
-
-    batch_size = int(params['batch_size']) // trainers_num
-    if params['mode'] != "test" and len(full_lines) < batch_size:
-        raise SampleNumException('', len(full_lines), batch_size)
-
-    def reader():
-        ops = create_operators(params['transforms'])
-        delimiter = params.get('delimiter', ' ')
-        for line in full_lines:
-            img_path, label = line.split(delimiter)
-            img_path = os.path.join(params['data_dir'], img_path)
-            with open(img_path, 'rb') as f:
-                img = f.read()
-            yield (transform(img, ops), int(label))
-
-    return reader
-
-
-def mp_reader(params):
-    """
-    multiprocess reader
-
-    Args:
-        params(dict):
-    """
-    check_params(params)
-
-    full_lines = get_file_list(params)
-    if params["mode"] == "train":
-        full_lines = shuffle_lines(full_lines, seed=None)
-
-    part_num = 1 if 'num_workers' not in params else params['num_workers']
-
-    readers = []
-    for part_id in range(part_num):
-        readers.append(partial_reader(params, full_lines, part_id, part_num))
-
-    return multiprocess_reader(readers, use_pipe=False)
-
-
 def term_mp(sig_num, frame):
     """ kill all child processes
     """
@@ -227,6 +168,29 @@
     pid = os.getpid()
     pgid = os.getpgid(os.getpid())
     logger.info("main proc {} exit, kill process group "
                 "{}".format(pid, pgid))
     os.killpg(pgid, signal.SIGKILL)
+    return
+
+
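+# map-style dataset over the label file: DataLoader indexes it sample by
+# sample, so image decoding and the transform ops run on demand (in worker
+# subprocesses when num_workers > 0)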
+class CommonDataset(Dataset):
+    def __init__(self, params):
+        self.params = params
+        self.mode = params.get("mode", "train")
+        self.full_lines = get_file_list(params)
+        self.delimiter = params.get('delimiter', ' ')
+        self.ops = create_operators(params['transforms'])
+        self.num_samples = len(self.full_lines)
+        return
+
+    def __getitem__(self, idx):
+        line = self.full_lines[idx]
+        img_path, label = line.split(self.delimiter)
+        img_path = os.path.join(self.params['data_dir'], img_path)
+        with open(img_path, 'rb') as f:
+            img = f.read()
+        return (transform(img, self.ops), int(label))
+
+    def __len__(self):
+        return self.num_samples
 
 
 class Reader:
@@ -242,7 +206,7 @@
         the specific reader
     """
 
-    def __init__(self, config, mode='train', seed=None):
+    def __init__(self, config, mode='train', places=None):
        try:
            self.params = config[mode.upper()]
        except KeyError:
@@ -250,27 +214,58 @@
         use_mix = config.get('use_mix')
         self.params['mode'] = mode
 
-        if seed is not None:
-            self.params['shuffle_seed'] = seed
+        self.shuffle = mode == "train"
+
+        self.collate_fn = None
         self.batch_ops = []
         if use_mix and mode == "train":
             self.batch_ops = create_operators(self.params['mix'])
+            self.collate_fn = self.mix_collate_fn
+
+        self.places = places
+
+    def mix_collate_fn(self, batch):
+        batch = transform(batch, self.batch_ops)
+        # batch each field
+        slots = []
+        for items in batch:
+            for i, item in enumerate(items):
+                if len(slots) < len(items):
+                    slots.append([item])
+                else:
+                    slots[i].append(item)
+
+        return [np.stack(slot, axis=0) for slot in slots]
 
     def __call__(self):
         batch_size = int(self.params['batch_size']) // trainers_num
 
-        def wrapper():
-            reader = mp_reader(self.params)
-            batch = []
-            for idx, sample in enumerate(reader()):
-                img, label = sample
-                batch.append((img, label))
-                if (idx + 1) % batch_size == 0:
-                    batch = transform(batch, self.batch_ops)
-                    yield batch
-                    batch = []
-
-        return wrapper
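+        # train mode shards and shuffles batches via DistributedBatchSampler;
+        # eval mode uses a plain DataLoader that keeps the file-list order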
+        dataset = CommonDataset(self.params)
+
+        if self.params['mode'] == "train":
+            batch_sampler = DistributedBatchSampler(
+                dataset,
+                batch_size=batch_size,
+                shuffle=self.shuffle,
+                drop_last=True)
+            loader = DataLoader(
+                dataset,
+                batch_sampler=batch_sampler,
+                collate_fn=self.collate_fn,
+                places=self.places,
+                return_list=True,
+                num_workers=self.params["num_workers"])
+        else:
+            loader = DataLoader(
+                dataset,
+                places=self.places,
+                batch_size=batch_size,
+                drop_last=False,
+                return_list=True,
+                shuffle=False,
+                num_workers=self.params["num_workers"])
+
+        return loader
 
 
 signal.signal(signal.SIGINT, term_mp)
diff --git a/ppcls/modeling/loss.py b/ppcls/modeling/loss.py
index 706fe04adf97f6d448cd4a76f98ecd8977739aac..0b705df481f26361c5750267f046852665ef038d 100644
--- a/ppcls/modeling/loss.py
+++ b/ppcls/modeling/loss.py
@@ -49,7 +49,6 @@ class Loss(object):
             input = -F.log_softmax(input, axis=-1)
             cost = paddle.reduce_sum(target * input, dim=-1)
         else:
-            # softmax_out = F.softmax(input)
             cost = F.cross_entropy(input=input, label=target)
         avg_cost = paddle.mean(cost)
         return avg_cost
diff --git a/tools/eval.py b/tools/eval.py
index 06349ca96aec7f6936f3753679b774bf3304926c..da37a2c38256677db602c8d46b3f85eeab917490 100644
--- a/tools/eval.py
+++ b/tools/eval.py
@@ -63,10 +63,9 @@ def main(args):
     net = program.create_model(config.ARCHITECTURE, config.classes_num)
     net = paddle.DataParallel(net, strategy)
     init_model(config, net, optimizer=None)
-    valid_dataloader = program.create_dataloader()
-    valid_reader = Reader(config, 'valid')()
-    valid_dataloader.set_sample_list_generator(valid_reader, place)
+    valid_dataloader = Reader(config, 'valid', places=place)()
     net.eval()
+
     top1_acc = program.run(valid_dataloader, config, net, None, None, 0,
                            'valid')
 
diff --git a/tools/train.py b/tools/train.py
index 5061f4f0be93c81e78c6286afc019aaad1c782c1..32af634966bb7e76991a0b9e29dbcb36168b37fe 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -23,7 +23,6 @@ __dir__ = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(__dir__)
 sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
 
-
 import paddle
 from paddle.distributed import ParallelEnv
 
@@ -33,6 +32,7 @@ from ppcls.utils.save_load import init_model, save_model
 from ppcls.utils import logger
 import program
+
 
 def parse_args():
     parser = argparse.ArgumentParser("PaddleClas train script")
     parser.add_argument(
@@ -78,16 +78,13 @@ def main(args):
     # load model from checkpoint or pretrained model
     init_model(config, net, optimizer)
 
-    train_dataloader = program.create_dataloader()
-    train_reader = Reader(config, 'train')()
-    train_dataloader.set_sample_list_generator(train_reader, place)
+    train_dataloader = Reader(config, 'train', places=place)()
 
-    if config.validate:
-        valid_dataloader = program.create_dataloader()
-        valid_reader = Reader(config, 'valid')()
-        valid_dataloader.set_sample_list_generator(valid_reader, place)
+    if config.validate and ParallelEnv().local_rank == 0:
+        valid_dataloader = Reader(config, 'valid', places=place)()
 
     best_top1_acc = 0.0  # best top1 acc record
+    best_top1_epoch = 0
     for epoch_id in range(config.epochs):
         net.train()
         # 1. train with train dataset
@@ -98,18 +95,18 @@
         # 2. validate with validate dataset
         if config.validate and epoch_id % config.valid_interval == 0:
             net.eval()
-            top1_acc = program.run(valid_dataloader, config, net, None, None,
-                                   epoch_id, 'valid')
+            top1_acc = program.run(valid_dataloader, config, net, None,
+                                   None, epoch_id, 'valid')
             if top1_acc > best_top1_acc:
                 best_top1_acc = top1_acc
-                message = "The best top1 acc {:.5f}, in epoch: {:d}".format(
-                    best_top1_acc, epoch_id)
-                logger.info("{:s}".format(logger.coloring(message, "RED")))
+                best_top1_epoch = epoch_id
                 if epoch_id % config.save_interval == 0:
-
                     model_path = os.path.join(config.model_save_dir,
                                               config.ARCHITECTURE["name"])
                     save_model(net, optimizer, model_path, "best_model")
+            message = "The best top1 acc {:.5f}, in epoch: {:d}".format(
+                best_top1_acc, best_top1_epoch)
+            logger.info("{:s}".format(logger.coloring(message, "RED")))
 
         # 3. save the persistable model
         if epoch_id % config.save_interval == 0: