未验证 提交 a4753fc3 编写于 作者: L littletomatodonkey 提交者: GitHub

Merge pull request #296 from littletomatodonkey/dyg/add_dataloader

add dataloader inferface
...@@ -3,33 +3,33 @@ ...@@ -3,33 +3,33 @@
hooks: hooks:
- id: yapf - id: yapf
files: \.py$ files: \.py$
- repo: https://github.com/pre-commit/pre-commit-hooks
- repo: https://github.com/pre-commit/mirrors-autopep8 sha: a11d9314b22d8f8c7556443875b731ef05965464
rev: v1.5
hooks: hooks:
- id: autopep8 - id: check-merge-conflict
- id: check-symlinks
- id: detect-private-key
files: (?!.*paddle)^.*$
- id: end-of-file-fixer
files: \.md$
- id: trailing-whitespace
files: \.md$
- repo: https://github.com/Lucas-C/pre-commit-hooks - repo: https://github.com/Lucas-C/pre-commit-hooks
sha: v1.0.1 sha: v1.0.1
hooks: hooks:
- id: forbid-crlf - id: forbid-crlf
files: \.(md|yml)$ files: \.md$
- id: remove-crlf - id: remove-crlf
files: \.(md|yml)$ files: \.md$
- id: forbid-tabs - id: forbid-tabs
files: \.(md|yml)$ files: \.md$
- id: remove-tabs - id: remove-tabs
files: \.(md|yml)$ files: \.md$
- repo: local
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.5.0
hooks: hooks:
- id: check-yaml - id: clang-format
- id: check-merge-conflict name: clang-format
- id: detect-private-key description: Format files with ClangFormat
files: (?!.*paddle)^.*$ entry: bash .clang_format.hook -i
- id: end-of-file-fixer language: system
files: \.(md|yml)$ files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|cuh|proto)$
- id: trailing-whitespace
files: \.(md|yml)$
- id: check-case-conflict
...@@ -54,7 +54,7 @@ TRAIN: ...@@ -54,7 +54,7 @@ TRAIN:
VALID: VALID:
batch_size: 32 batch_size: 1024
num_workers: 4 num_workers: 4
file_list: "./dataset/ILSVRC2012/val_list.txt" file_list: "./dataset/ILSVRC2012/val_list.txt"
data_dir: "./dataset/ILSVRC2012/" data_dir: "./dataset/ILSVRC2012/"
......
...@@ -17,7 +17,7 @@ import imghdr ...@@ -17,7 +17,7 @@ import imghdr
import os import os
import signal import signal
from paddle.reader import multiprocess_reader from paddle.io import Dataset, DataLoader, DistributedBatchSampler
from . import imaug from . import imaug
from .imaug import transform from .imaug import transform
...@@ -109,7 +109,6 @@ def create_file_list(params): ...@@ -109,7 +109,6 @@ def create_file_list(params):
def shuffle_lines(full_lines, seed=None): def shuffle_lines(full_lines, seed=None):
""" """
random shuffle lines random shuffle lines
Args: Args:
full_lines(list): full_lines(list):
seed(int): random seed seed(int): random seed
...@@ -135,12 +134,8 @@ def get_file_list(params): ...@@ -135,12 +134,8 @@ def get_file_list(params):
with open(params['file_list']) as flist: with open(params['file_list']) as flist:
full_lines = [line.strip() for line in flist] full_lines = [line.strip() for line in flist]
full_lines = shuffle_lines(full_lines, params["shuffle_seed"]) if params["mode"] == "train":
full_lines = shuffle_lines(full_lines, seed=params['shuffle_seed'])
# use only partial data for each trainer in distributed training
if params['mode'] == 'train':
img_per_trainer = len(full_lines) // trainers_num
full_lines = full_lines[trainer_id::trainers_num][:img_per_trainer]
return full_lines return full_lines
...@@ -165,60 +160,6 @@ def create_operators(params): ...@@ -165,60 +160,6 @@ def create_operators(params):
return ops return ops
def partial_reader(params, full_lines, part_id=0, part_num=1):
"""
create a reader with partial data
Args:
params(dict):
full_lines: label list
part_id(int): part index of the current partial data
part_num(int): part num of the dataset
"""
assert part_id < part_num, ("part_num: {} should be larger "
"than part_id: {}".format(part_num, part_id))
full_lines = full_lines[part_id::part_num]
batch_size = int(params['batch_size']) // trainers_num
if params['mode'] != "test" and len(full_lines) < batch_size:
raise SampleNumException('', len(full_lines), batch_size)
def reader():
ops = create_operators(params['transforms'])
delimiter = params.get('delimiter', ' ')
for line in full_lines:
img_path, label = line.split(delimiter)
img_path = os.path.join(params['data_dir'], img_path)
with open(img_path, 'rb') as f:
img = f.read()
yield (transform(img, ops), int(label))
return reader
def mp_reader(params):
"""
multiprocess reader
Args:
params(dict):
"""
check_params(params)
full_lines = get_file_list(params)
if params["mode"] == "train":
full_lines = shuffle_lines(full_lines, seed=None)
part_num = 1 if 'num_workers' not in params else params['num_workers']
readers = []
for part_id in range(part_num):
readers.append(partial_reader(params, full_lines, part_id, part_num))
return multiprocess_reader(readers, use_pipe=False)
def term_mp(sig_num, frame): def term_mp(sig_num, frame):
""" kill all child processes """ kill all child processes
""" """
...@@ -227,6 +168,29 @@ def term_mp(sig_num, frame): ...@@ -227,6 +168,29 @@ def term_mp(sig_num, frame):
logger.info("main proc {} exit, kill process group " logger.info("main proc {} exit, kill process group "
"{}".format(pid, pgid)) "{}".format(pid, pgid))
os.killpg(pgid, signal.SIGKILL) os.killpg(pgid, signal.SIGKILL)
return
class CommonDataset(Dataset):
def __init__(self, params):
self.params = params
self.mode = params.get("mode", "train")
self.full_lines = get_file_list(params)
self.delimiter = params.get('delimiter', ' ')
self.ops = create_operators(params['transforms'])
self.num_samples = len(self.full_lines)
return
def __getitem__(self, idx):
line = self.full_lines[idx]
img_path, label = line.split(self.delimiter)
img_path = os.path.join(self.params['data_dir'], img_path)
with open(img_path, 'rb') as f:
img = f.read()
return (transform(img, self.ops), int(label))
def __len__(self):
return self.num_samples
class Reader: class Reader:
...@@ -242,7 +206,7 @@ class Reader: ...@@ -242,7 +206,7 @@ class Reader:
the specific reader the specific reader
""" """
def __init__(self, config, mode='train', seed=None): def __init__(self, config, mode='train', places=None):
try: try:
self.params = config[mode.upper()] self.params = config[mode.upper()]
except KeyError: except KeyError:
...@@ -250,27 +214,58 @@ class Reader: ...@@ -250,27 +214,58 @@ class Reader:
use_mix = config.get('use_mix') use_mix = config.get('use_mix')
self.params['mode'] = mode self.params['mode'] = mode
if seed is not None: self.shuffle = mode == "train"
self.params['shuffle_seed'] = seed
self.collate_fn = None
self.batch_ops = [] self.batch_ops = []
if use_mix and mode == "train": if use_mix and mode == "train":
self.batch_ops = create_operators(self.params['mix']) self.batch_ops = create_operators(self.params['mix'])
self.collate_fn = self.mix_collate_fn
def __call__(self): self.places = places
batch_size = int(self.params['batch_size']) // trainers_num
def wrapper(): def mix_collate_fn(self, batch):
reader = mp_reader(self.params)
batch = []
for idx, sample in enumerate(reader()):
img, label = sample
batch.append((img, label))
if (idx + 1) % batch_size == 0:
batch = transform(batch, self.batch_ops) batch = transform(batch, self.batch_ops)
yield batch # batch each field
batch = [] slots = []
for items in batch:
for i, item in enumerate(items):
if len(slots) < len(items):
slots.append([item])
else:
slots[i].append(item)
return wrapper return [np.stack(slot, axis=0) for slot in slots]
def __call__(self):
batch_size = int(self.params['batch_size']) // trainers_num
dataset = CommonDataset(self.params)
if self.params['mode'] == "train":
batch_sampler = DistributedBatchSampler(
dataset,
batch_size=batch_size,
shuffle=self.shuffle,
drop_last=True)
loader = DataLoader(
dataset,
batch_sampler=batch_sampler,
collate_fn=self.collate_fn,
places=self.places,
return_list=True,
num_workers=self.params["num_workers"])
else:
loader = DataLoader(
dataset,
places=self.places,
batch_size=batch_size,
drop_last=False,
return_list=True,
shuffle=False,
num_workers=self.params["num_workers"])
return loader
signal.signal(signal.SIGINT, term_mp) signal.signal(signal.SIGINT, term_mp)
......
...@@ -49,7 +49,6 @@ class Loss(object): ...@@ -49,7 +49,6 @@ class Loss(object):
input = -F.log_softmax(input, axis=-1) input = -F.log_softmax(input, axis=-1)
cost = paddle.reduce_sum(target * input, dim=-1) cost = paddle.reduce_sum(target * input, dim=-1)
else: else:
# softmax_out = F.softmax(input)
cost = F.cross_entropy(input=input, label=target) cost = F.cross_entropy(input=input, label=target)
avg_cost = paddle.mean(cost) avg_cost = paddle.mean(cost)
return avg_cost return avg_cost
......
...@@ -63,10 +63,9 @@ def main(args): ...@@ -63,10 +63,9 @@ def main(args):
net = program.create_model(config.ARCHITECTURE, config.classes_num) net = program.create_model(config.ARCHITECTURE, config.classes_num)
net = paddle.DataParallel(net, strategy) net = paddle.DataParallel(net, strategy)
init_model(config, net, optimizer=None) init_model(config, net, optimizer=None)
valid_dataloader = program.create_dataloader() valid_dataloader = Reader(config, 'valid', places=place)()
valid_reader = Reader(config, 'valid')()
valid_dataloader.set_sample_list_generator(valid_reader, place)
net.eval() net.eval()
top1_acc = program.run(valid_dataloader, config, net, None, None, 0, top1_acc = program.run(valid_dataloader, config, net, None, None, 0,
'valid') 'valid')
......
...@@ -23,7 +23,6 @@ __dir__ = os.path.dirname(os.path.abspath(__file__)) ...@@ -23,7 +23,6 @@ __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__) sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
import paddle import paddle
from paddle.distributed import ParallelEnv from paddle.distributed import ParallelEnv
...@@ -33,6 +32,7 @@ from ppcls.utils.save_load import init_model, save_model ...@@ -33,6 +32,7 @@ from ppcls.utils.save_load import init_model, save_model
from ppcls.utils import logger from ppcls.utils import logger
import program import program
def parse_args(): def parse_args():
parser = argparse.ArgumentParser("PaddleClas train script") parser = argparse.ArgumentParser("PaddleClas train script")
parser.add_argument( parser.add_argument(
...@@ -78,16 +78,13 @@ def main(args): ...@@ -78,16 +78,13 @@ def main(args):
# load model from checkpoint or pretrained model # load model from checkpoint or pretrained model
init_model(config, net, optimizer) init_model(config, net, optimizer)
train_dataloader = program.create_dataloader() train_dataloader = Reader(config, 'train', places=place)()
train_reader = Reader(config, 'train')()
train_dataloader.set_sample_list_generator(train_reader, place)
if config.validate: if config.validate and ParallelEnv().local_rank == 0:
valid_dataloader = program.create_dataloader() valid_dataloader = Reader(config, 'valid', places=place)()
valid_reader = Reader(config, 'valid')()
valid_dataloader.set_sample_list_generator(valid_reader, place)
best_top1_acc = 0.0 # best top1 acc record best_top1_acc = 0.0 # best top1 acc record
best_top1_epoch = 0
for epoch_id in range(config.epochs): for epoch_id in range(config.epochs):
net.train() net.train()
# 1. train with train dataset # 1. train with train dataset
...@@ -98,18 +95,18 @@ def main(args): ...@@ -98,18 +95,18 @@ def main(args):
# 2. validate with validate dataset # 2. validate with validate dataset
if config.validate and epoch_id % config.valid_interval == 0: if config.validate and epoch_id % config.valid_interval == 0:
net.eval() net.eval()
top1_acc = program.run(valid_dataloader, config, net, None, None, top1_acc = program.run(valid_dataloader, config, net, None,
epoch_id, 'valid') None, epoch_id, 'valid')
if top1_acc > best_top1_acc: if top1_acc > best_top1_acc:
best_top1_acc = top1_acc best_top1_acc = top1_acc
message = "The best top1 acc {:.5f}, in epoch: {:d}".format( best_top1_epoch = epoch_id
best_top1_acc, epoch_id)
logger.info("{:s}".format(logger.coloring(message, "RED")))
if epoch_id % config.save_interval == 0: if epoch_id % config.save_interval == 0:
model_path = os.path.join(config.model_save_dir, model_path = os.path.join(config.model_save_dir,
config.ARCHITECTURE["name"]) config.ARCHITECTURE["name"])
save_model(net, optimizer, model_path, "best_model") save_model(net, optimizer, model_path, "best_model")
message = "The best top1 acc {:.5f}, in epoch: {:d}".format(
best_top1_acc, best_top1_epoch)
logger.info("{:s}".format(logger.coloring(message, "RED")))
# 3. save the persistable model # 3. save the persistable model
if epoch_id % config.save_interval == 0: if epoch_id % config.save_interval == 0:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册