提交 f2e358b4 编写于 作者: Y Yancey1989

add readme

上级 3aa8cb2f
# PaddlePaddle Fast ResNet
PaddlePaddle Fast ResNet can train ImageNet in fewer epochs. We implemented it following the blog post
[Now anyone can train Imagenet in 18 minutes](https://www.fast.ai/2018/08/10/fastai-diu-imagenet/), which was published on the [fast.ai](https://www.fast.ai) website.
PaddlePaddle Fast ResNet uses dynamic batch sizes, dynamic image sizes, rectangular image validation, etc., so that Fast ResNet can reach the baseline
(acc1: 75%, acc5: 93%) within 27 epochs on 8 GPUs.
## Experiment
1. Prepare the training data: resize the images to 160 and 352 with `resize.py`. The prepared data folder is laid out as follows:
``` text
`-ImageNet
|-train
|-validation
|-160
|-train
`-validation
`-352
|-train
`-validation
```
1. Install the requirements by `pip install -r requirement.txt`.
1. Launch the training job: `python train.py --data_dir /data/imagenet`
1. Learning curve; we launched the training job on V100 GPU cards:
<p align="center">
<img src="src/acc_curve.png" hspace='10' /> <br />
</p>
import os
def dist_env():
    """Collect distributed-training settings from environment variables.

    Returns a dict with:
        trainer_id (int): index of this trainer (PADDLE_TRAINER_ID, default 0).
        num_trainers (int): total number of trainers.
        current_endpoint (str): this process's "ip:port" endpoint.
        training_role (str): "TRAINER" or "PSERVER".
        pserver_endpoints (str): comma-separated pserver "ip:port" list.
        trainer_endpoints (list | str): trainer endpoints, split into a list
            when non-empty, otherwise the raw (empty) string.

    NOTE: you may rewrite this function to suit your cluster environments.
    """
    trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
    num_trainers = 1
    training_role = os.getenv("TRAINING_ROLE", "TRAINER")
    assert(training_role == "PSERVER" or training_role == "TRAINER")
    # - PADDLE_TRAINER_ENDPOINTS means nccl2 mode.
    # - PADDLE_PSERVER_ENDPOINTS means pserver mode.
    # - PADDLE_CURRENT_ENDPOINT means current process endpoint.
    port = os.getenv("PADDLE_PORT", "8701")
    if os.getenv("PADDLE_TRAINER_ENDPOINTS"):
        trainer_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS")
    else:  # for paddlecloud: build endpoints from the trainer ip list.
        # BUGFIX: skip empty entries so an unset/empty PADDLE_TRAINERS no
        # longer yields a bogus ":8701" endpoint (and a wrong trainer count).
        worker_ips = os.getenv("PADDLE_TRAINERS", "")
        worker_endpoints = [
            ':'.join([ip, port]) for ip in worker_ips.split(",") if ip
        ]
        trainer_endpoints = ",".join(worker_endpoints)
    # Same empty-entry filtering for the pserver ip list.
    pserver_ips = os.getenv("PADDLE_PSERVERS", "")
    eplist = [':'.join([ip, port]) for ip in pserver_ips.split(",") if ip]
    pserver_endpoints = ",".join(eplist)
    if os.getenv("PADDLE_CURRENT_ENDPOINT"):
        current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT")
    else:  # for paddlecloud
        current_endpoint = os.getenv("POD_IP", "") + ":" + port
    if trainer_endpoints:
        # nccl2 mode: the trainer count is the endpoint-list length.
        trainer_endpoints = trainer_endpoints.split(",")
        num_trainers = len(trainer_endpoints)
    elif pserver_endpoints:
        # pserver mode: the trainer count comes from the environment.
        num_trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
    return {
        "trainer_id": trainer_id,
        "num_trainers": num_trainers,
        "current_endpoint": current_endpoint,
        "training_role": training_role,
        "pserver_endpoints": pserver_endpoints,
        "trainer_endpoints": trainer_endpoints
    }
import os
import numpy as np
import math
import random
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import pickle
from tqdm import tqdm
import time
import multiprocessing
# Trainer count / this trainer's index, taken from the environment.
# NOTE(review): env.py reads "PADDLE_TRAINERS_NUM" while this reads
# "PADDLE_TRAINER_NUM" -- confirm which variable the cluster actually sets.
TRAINER_NUMS = int(os.getenv("PADDLE_TRAINER_NUM", "1"))
TRAINER_ID = int(os.getenv("PADDLE_TRAINER_ID", "0"))
# Module-level epoch counter.
epoch = 0


# NOTE(review): apparently an unused stub -- the code below uses
# torchvision's datasets.ImageFolder, not this class.
class ImageFolder(object):
    def __init__(self, root, transforms):
        pass
# Sentinel a worker process enqueues when it has exhausted its indices.
FINISH_EVENT = "FINISH_EVENT"


class PaddleDataLoader(object):
    """Multi-process data loader over an indexable dataset.

    ``concurrent`` worker processes read disjoint index slices from
    ``torch_dataset``, convert each image to CHW uint8, and push
    ``(img, label)`` tuples into a bounded queue; ``reader()`` yields
    them until every worker has signalled completion.
    """

    def __init__(self, torch_dataset, indices=None, concurrent=16, queue_size=3072):
        # torch_dataset: supports len() and integer indexing, returning
        # (image, label); images must be HWC array-like (e.g. PIL images).
        self.torch_dataset = torch_dataset
        self.data_queue = multiprocessing.Queue(queue_size)
        self.indices = indices          # optional fixed read order
        self.concurrent = concurrent    # number of worker processes

    def _worker_loop(self, dataset, worker_indices, worker_id):
        """Read every index in ``worker_indices``, enqueue samples, then signal."""
        cnt = 0
        for idx in worker_indices:
            cnt += 1
            img, label = self.torch_dataset[idx]
            # HWC uint8 -> CHW, the layout the consumer expects.
            img = np.array(img).astype('uint8').transpose((2, 0, 1))
            self.data_queue.put((img, label))
        print("worker: [%d] read [%d] samples. " % (worker_id, cnt))
        self.data_queue.put(FINISH_EVENT)

    def reader(self):
        """Return a generator factory yielding (img, label) samples."""
        def _reader_creator():
            worker_processes = []
            total_img = len(self.torch_dataset)
            print("total image: ", total_img)
            if self.indices is None:
                # BUGFIX: xrange is Python 2 only; the rest of the project
                # (pathlib, true division in resize.py) is Python 3 -- use range.
                self.indices = [i for i in range(total_img)]
                random.seed(time.time())
                random.shuffle(self.indices)
            print("shuffle indices: %s ..." % self.indices[:10])
            # Ceil so every index is covered; the last worker takes the tail.
            imgs_per_worker = int(math.ceil(total_img / self.concurrent))
            for i in range(self.concurrent):
                start = i * imgs_per_worker
                end = (i + 1) * imgs_per_worker if i != self.concurrent - 1 else None
                sliced_indices = self.indices[start:end]
                w = multiprocessing.Process(
                    target=self._worker_loop,
                    args=(self.torch_dataset, sliced_indices, i)
                )
                w.daemon = True
                w.start()
                worker_processes.append(w)
            # Drain the queue until each worker has posted its sentinel.
            finish_workers = 0
            worker_cnt = len(worker_processes)
            while finish_workers < worker_cnt:
                sample = self.data_queue.get()
                if sample == FINISH_EVENT:
                    finish_workers += 1
                else:
                    yield sample
        return _reader_creator
def train(traindir, sz, min_scale=0.08):
    """Build a training reader over the ImageFolder at ``traindir``.

    Each image is randomly crop-resized to ``sz`` (area scale in
    [min_scale, 1.0]) and randomly flipped horizontally.
    """
    augmentations = transforms.Compose([
        transforms.RandomResizedCrop(sz, scale=(min_scale, 1.0)),
        transforms.RandomHorizontalFlip(),
    ])
    dataset = datasets.ImageFolder(traindir, augmentations)
    return PaddleDataLoader(dataset).reader()
def test(valdir, bs, sz, rect_val=False):
    """Build a validation reader.

    With ``rect_val`` the samples are read in aspect-ratio order and
    center-cropped to per-batch rectangular sizes; otherwise a plain
    resize followed by a square center crop is applied.
    """
    resize_dim = int(sz * 1.14)
    if rect_val:
        idx_ar_sorted = sort_ar(valdir)
        idx_sorted, _ = zip(*idx_ar_sorted)
        idx2ar = map_idx2ar(idx_ar_sorted, bs)
        rect_tfms = [transforms.Resize(resize_dim), CropArTfm(idx2ar, sz)]
        rect_dataset = ValDataset(valdir, transform=rect_tfms)
        return PaddleDataLoader(rect_dataset, concurrent=1, indices=idx_sorted).reader()
    square_tfms = transforms.Compose([
        transforms.Resize(resize_dim),
        transforms.CenterCrop(sz),
    ])
    square_dataset = datasets.ImageFolder(valdir, square_tfms)
    return PaddleDataLoader(square_dataset).reader()
class ValDataset(datasets.ImageFolder):
    """ImageFolder variant whose ``transform`` is a *list* of callables.

    ``CropArTfm`` entries additionally receive the sample index so the
    crop size can depend on the image's aspect-ratio bucket.
    """

    def __init__(self, root, transform=None, target_transform=None):
        super(ValDataset, self).__init__(root, transform, target_transform)

    def __getitem__(self, index):
        path, target = self.imgs[index]
        sample = self.loader(path)
        if self.transform is not None:
            for tfm in self.transform:
                # CropArTfm needs the index to look up its aspect ratio.
                if isinstance(tfm, CropArTfm):
                    sample = tfm(sample, index)
                    continue
                sample = tfm(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target
class CropArTfm(object):
    """Center-crop an image to a rectangle for its aspect-ratio bucket.

    ``idx2ar`` maps sample index -> batch-mean aspect ratio (width/height,
    as computed by ``sort_ar``); ``target_size`` is the short-side length.
    The long side is rounded down to a multiple of 8.
    """

    def __init__(self, idx2ar, target_size):
        self.idx2ar = idx2ar
        self.target_size = target_size

    def __call__(self, img, idx):
        target_ar = self.idx2ar[idx]
        if target_ar < 1:
            # Portrait bucket: stretch the first crop dimension.
            long_side = int(self.target_size / target_ar)
            size = (long_side // 8 * 8, self.target_size)
        else:
            # Landscape bucket: stretch the second crop dimension.
            long_side = int(self.target_size * target_ar)
            size = (self.target_size, long_side // 8 * 8)
        return transforms.functional.center_crop(img, size)
def sort_ar(valdir):
    """Return [(index, aspect_ratio)] for ``valdir``, sorted by ratio.

    Aspect ratio is width / height rounded to 5 decimals. The result is
    cached in a pickle file next to ``valdir`` and returned directly on
    subsequent calls.
    """
    idx2ar_file = valdir + '/../sorted_idxar.p'
    if os.path.isfile(idx2ar_file):
        # BUGFIX: close the cache file instead of leaking the handle.
        with open(idx2ar_file, 'rb') as f:
            return pickle.load(f)
    print('Creating AR indexes. Please be patient this may take a couple minutes...')
    val_dataset = datasets.ImageFolder(valdir)  # AS: TODO: use Image.open instead of looping through dataset
    sizes = [img[0].size for img in tqdm(val_dataset, total=len(val_dataset))]
    idx_ar = [(i, round(s[0] * 1.0 / s[1], 5)) for i, s in enumerate(sizes)]
    sorted_idxar = sorted(idx_ar, key=lambda x: x[1])
    # BUGFIX: same handle leak on the write path.
    with open(idx2ar_file, 'wb') as f:
        pickle.dump(sorted_idxar, f)
    print('Done')
    return sorted_idxar
def chunks(l, n):
    """Yield successive length-``n`` slices of ``l`` (last may be shorter).

    ``n`` is clamped to at least 1 so a zero/negative chunk size cannot
    loop forever.
    """
    step = max(1, n)
    for start in range(0, len(l), step):
        yield l[start:start + step]
def map_idx2ar(idx_ar_sorted, batch_size):
    """Map each sample index to the mean aspect ratio of its batch chunk.

    ``idx_ar_sorted`` is a list of (index, aspect_ratio) pairs; each run
    of ``batch_size`` consecutive pairs shares one mean ratio (rounded to
    5 decimals) so a whole batch gets the same rectangular crop.
    """
    idx2ar = {}
    step = max(1, batch_size)
    for start in range(0, len(idx_ar_sorted), step):
        chunk = idx_ar_sorted[start:start + step]
        indices, ratios = zip(*chunk)
        chunk_mean = round(np.mean(ratios), 5)
        idx2ar.update((i, chunk_mean) for i in indices)
    return idx2ar
if __name__ == "__main__":
    # Ad-hoc smoke test for the rectangular-validation reader; paths are
    # hard-coded to the author's cluster layout.
    #ds, sampler = create_validation_set("/data/imagenet/validation", 128, 288, True, True)
    #for item in sampler:
    #    for idx in item:
    #        ds[idx]
    import time
    test_reader = test(valdir="/data/imagenet/validation", bs=50, sz=288, rect_val=True)
    start_ts = time.time()
    for idx, data in enumerate(test_reader()):
        print(idx, data[0].shape, data[1])
        # NOTE(review): this debug break stops after 11 samples, so the
        # throughput report below can never trigger (idx never reaches 999);
        # remove the break to actually measure samples/second.
        if idx == 10:
            break
        if (idx + 1) % 1000 == 0:
            cost = (time.time() - start_ts)
            print("%d samples per second" % (1000 / cost))
            start_ts = time.time()
\ No newline at end of file
from PIL import Image
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
from functools import partial
import multiprocessing
# Worker-pool size: one per CPU, capped at 36 to avoid oversubscription.
cpus = multiprocessing.cpu_count()
cpus = min(36,cpus)
# Source ImageNet tree and destination root for the resized copies.
PATH = Path('/data/imagenet2')
DEST = Path('/data/imagenet2/sz')
def mkdir(path):
    """Create directory ``path`` if it does not exist (idempotent).

    Uses ``exist_ok`` instead of the original check-then-create
    (``if not path.exists(): path.mkdir()``), which could raise if
    another worker process created the directory between the two calls.
    """
    path.mkdir(exist_ok=True)
# Create the destination root up front.
mkdir(DEST)
# Short-side target sizes to generate; 352 is currently disabled.
#szs = (160, 352)
szs = (160,)
def resize_img(p, im, fn, sz):
    """Resize ``im`` so its short side is ``sz`` and save under DEST/sz.

    Args:
        p: split directory being processed (unused here; kept for
           interface parity with the pool callback).
        im: an opened PIL image.
        fn: source file path; its path relative to PATH is mirrored
            under DEST/<sz>/.
        sz: target short-side length in pixels.
    """
    w,h = im.size
    ratio = min(h/sz,w/sz)
    im = im.resize((int(w/ratio), int(h/ratio)), resample=Image.BICUBIC)
    new_fn = DEST/str(sz)/fn.relative_to(PATH)
    # BUGFIX: Path.parent is a property, not a method -- the original
    # ``mkdir(new_fn.parent())`` raised TypeError. Also create intermediate
    # directories (per-class subfolders), which a single mkdir cannot.
    new_fn.parent.mkdir(parents=True, exist_ok=True)
    im.save(new_fn)
def resizes(p, fn):
    """Open ``fn`` once and write one resized copy per configured size."""
    image = Image.open(fn)
    for size in szs:
        resize_img(p, image, fn, size)
def resize_imgs(p):
    """Resize every ``*/*.jpeg`` under split directory ``p`` via a process pool."""
    # NOTE(review): the glob is case-sensitive, so files named *.JPEG
    # (the usual ImageNet extension) would be skipped -- verify against
    # the actual dataset layout.
    files = p.glob('*/*.jpeg')
    with ProcessPoolExecutor(cpus) as e: e.map(partial(resizes, p), files)
# Pre-create DEST/<size>/{validation,train} for every target size.
for sz in szs:
    ssz=str(sz)
    mkdir((DEST/ssz))
    for ds in ('validation','train'): mkdir((DEST/ssz/ds))
    # NOTE(review): redundant -- 'train' was already created just above.
    for ds in ('train',): mkdir((DEST/ssz/ds))

# Resize the selected splits; only 'validation' is currently enabled.
#for ds in ('val','train'): resize_imgs(PATH/ds)
#for ds in ("validation", "train"):
for ds in ("validation", ):
    print(PATH/ds)
    resize_imgs(PATH/ds)
\ No newline at end of file
......@@ -33,7 +33,6 @@ from utility import add_arguments, print_arguments
import functools
from models.fast_resnet import FastResNet, lr_decay
import utils
from env import dist_env
def parse_args():
parser = argparse.ArgumentParser(description=__doc__)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册