提交 29153a7b 编写于 作者: Y Yancey1989

merge reader

上级 05d1f586
...@@ -5,6 +5,8 @@ large-scaled distributed training with two distributed architecture: parameter s ...@@ -5,6 +5,8 @@ large-scaled distributed training with two distributed architecture: parameter s
## Getting Started ## Getting Started
Before getting started, please make sure you have finished the imagenet [Data Preparation](../README.md#data-preparation).
1. The entrypoint file is `dist_train.py`, some important flags are as follows: 1. The entrypoint file is `dist_train.py`, some important flags are as follows:
- `model`, the model to run with, such as `ResNet50`, `ResNet101`, etc. - `model`, the model to run with, such as `ResNet50`, `ResNet101`, etc.
...@@ -43,7 +45,8 @@ In this example, we launched 4 parameter server instances and 4 trainer instance ...@@ -43,7 +45,8 @@ In this example, we launched 4 parameter server instances and 4 trainer instance
--model=ResNet50 \ --model=ResNet50 \
--batch_size=32 \ --batch_size=32 \
--update_method=pserver \ --update_method=pserver \
--device=CPU --device=CPU \
--data_dir=../data/ILSVRC2012
``` ```
1. launch trainer process 1. launch trainer process
...@@ -58,7 +61,8 @@ In this example, we launched 4 parameter server instances and 4 trainer instance ...@@ -58,7 +61,8 @@ In this example, we launched 4 parameter server instances and 4 trainer instance
--model=ResNet50 \ --model=ResNet50 \
--batch_size=32 \ --batch_size=32 \
--update_method=pserver \ --update_method=pserver \
--device=GPU --device=GPU \
--data_dir=../data/ILSVRC2012
``` ```
...@@ -75,7 +79,8 @@ In this example, we launched 4 parameter server instances and 4 trainer instance ...@@ -75,7 +79,8 @@ In this example, we launched 4 parameter server instances and 4 trainer instance
--model=ResNet50 \ --model=ResNet50 \
--batch_size=32 \ --batch_size=32 \
--update_method=pserver \ --update_method=pserver \
--device=GPU --device=GPU \
--data_dir=../data/ILSVRC2012
``` ```
### Training Curve ### Training Curve
......
...@@ -123,5 +123,11 @@ def parse_args(): ...@@ -123,5 +123,11 @@ def parse_args():
choices=['reduce', 'all_reduce'], choices=['reduce', 'all_reduce'],
default='all_reduce', default='all_reduce',
help='Specify the reduce strategy, can be reduce, all_reduce') help='Specify the reduce strategy, can be reduce, all_reduce')
parser.add_argument(
'--data_dir',
type=str,
default="../data/ILSVRC2012",
help="The ImageNet dataset root dir."
)
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -26,9 +26,8 @@ import paddle.fluid.transpiler.distribute_transpiler as distribute_transpiler ...@@ -26,9 +26,8 @@ import paddle.fluid.transpiler.distribute_transpiler as distribute_transpiler
import sys import sys
sys.path.append("..") sys.path.append("..")
import models import models
from imagenet_reader import train, val
from args import * from args import *
from reader import train, val
def get_model(args, is_train, main_prog, startup_prog): def get_model(args, is_train, main_prog, startup_prog):
pyreader = None pyreader = None
...@@ -38,9 +37,9 @@ def get_model(args, is_train, main_prog, startup_prog): ...@@ -38,9 +37,9 @@ def get_model(args, is_train, main_prog, startup_prog):
else: else:
dshape = [224, 224, 3] dshape = [224, 224, 3]
if is_train: if is_train:
reader = train(xmap=False) reader = train(data_dir=args.data_dir)
else: else:
reader = val(xmap=False) reader = val(data_dir=args.data_dir)
trainer_count = int(os.getenv("PADDLE_TRAINERS", "1")) trainer_count = int(os.getenv("PADDLE_TRAINERS", "1"))
with fluid.program_guard(main_prog, startup_prog): with fluid.program_guard(main_prog, startup_prog):
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import math
import random
import functools
import numpy as np
from threading import Thread
import subprocess
import time
from Queue import Queue
import paddle
from PIL import Image, ImageEnhance
# Seed the `random` module once at import time so the shuffles and random
# augmentations in this module start from a fixed state.
random.seed(0)
# Side length, in pixels, of the square image produced by process_image.
DATA_DIM = 224
# Worker count handed to paddle.reader.xmap_readers; can be overridden through
# the PREPROCESS_THREADS environment variable.
THREAD = int(os.getenv("PREPROCESS_THREADS", "10"))
# Buffer size handed to paddle.reader.xmap_readers.
BUF_SIZE = 5120
# Dataset root that the image paths read from the list files are joined onto.
DATA_DIR = '../data/ILSVRC2012'
# Default list files: TRAIN_LIST for train()/train_raw(), TEST_LIST for
# val()/test().
TRAIN_LIST = '../data/ILSVRC2012/train.txt'
TEST_LIST = '../data/ILSVRC2012/val.txt'
# Per-channel mean and standard deviation used by process_image to normalize
# the image; shaped (3, 1, 1) so they broadcast over a channel-first array.
img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
def resize_short(img, target_size):
    """Scale ``img`` so that its shorter side becomes ``target_size``.

    Both sides are multiplied by the same factor and rounded to the nearest
    integer, so the aspect ratio is kept up to rounding. The resized image is
    returned; resampling uses ``Image.LANCZOS``.
    """
    width, height = img.size[0], img.size[1]
    factor = float(target_size) / min(width, height)
    new_size = (int(round(width * factor)), int(round(height * factor)))
    return img.resize(new_size, Image.LANCZOS)
def crop_image(img, target_size, center):
    """Crop a ``target_size`` x ``target_size`` square out of ``img``.

    Args:
        img: image object exposing ``size`` as ``(width, height)`` and a
            ``crop(box)`` method.
        target_size: side length of the crop, in pixels.
        center: if truthy, the square is centered in the image; otherwise its
            top-left corner is drawn uniformly at random.

    Returns:
        The result of ``img.crop`` for the chosen
        ``(left, upper, right, lower)`` box.
    """
    width, height = img.size
    size = target_size
    if center:
        # Floor division keeps the offsets integral under Python 3 as well;
        # plain `/` would produce float pixel coordinates there.
        w_start = (width - size) // 2
        h_start = (height - size) // 2
    else:
        w_start = random.randint(0, width - size)
        h_start = random.randint(0, height - size)
    box = (w_start, h_start, w_start + size, h_start + size)
    return img.crop(box)
def random_crop(img, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.)):
    """Take a random-area, random-aspect crop of ``img`` and resize it.

    Args:
        img: image object exposing ``size`` as ``(width, height)``, ``crop``
            and ``resize``.
        size: side length of the square output, in pixels.
        scale: ``(min, max)`` fraction of the source area the crop may cover.
            Both ends are clamped so that the crop fits inside the image.
        ratio: ``(min, max)`` range from which the width/height aspect ratio
            of the crop is drawn.

    Returns:
        The cropped region resized to ``size`` x ``size`` with
        ``Image.LANCZOS`` resampling.
    """
    # Defaults are tuples rather than lists so they cannot be mutated across
    # calls; they are only read (indexed / unpacked) here.
    aspect_ratio = math.sqrt(random.uniform(*ratio))
    w = 1. * aspect_ratio
    h = 1. / aspect_ratio

    img_w, img_h = img.size[0], img.size[1]
    # Largest area fraction for which a crop with this aspect ratio still
    # fits inside the image in both dimensions.
    bound = min((float(img_w) / img_h) / (w**2),
                (float(img_h) / img_w) / (h**2))
    scale_max = min(scale[1], bound)
    scale_min = min(scale[0], bound)

    target_area = img_w * img_h * random.uniform(scale_min, scale_max)
    target_size = math.sqrt(target_area)
    w = int(target_size * w)
    h = int(target_size * h)

    i = random.randint(0, img_w - w)
    j = random.randint(0, img_h - h)

    img = img.crop((i, j, i + w, j + h))
    return img.resize((size, size), Image.LANCZOS)
def rotate_image(img):
    """Return ``img.rotate(angle)`` for a random integer angle in [-10, 10]."""
    return img.rotate(random.randint(-10, 10))
def distort_color(img):
    """Apply brightness, contrast and color enhancement in a random order.

    The three enhancers are shuffled, then each one is applied with its own
    factor drawn uniformly from [0.5, 1.5]. The enhanced image is returned.
    """
    enhancers = [
        ImageEnhance.Brightness,
        ImageEnhance.Contrast,
        ImageEnhance.Color,
    ]
    random.shuffle(enhancers)
    for enhancer in enhancers:
        # Draw the factor before building the enhancer, one draw per step.
        factor = random.uniform(0.5, 1.5)
        img = enhancer(img).enhance(factor)
    return img
def process_image(sample, mode, color_jitter, rotate):
    """Load the image named by ``sample[0]`` and turn it into a network input.

    In 'train' mode the image is optionally rotated, randomly cropped to
    DATA_DIM, optionally color-jittered and flipped left-right with
    probability 1/2. In every other mode it is resized so the short side is
    256 and center-cropped to DATA_DIM. The result is converted to RGB,
    scaled to [0, 1] as a channel-first float32 array, and normalized with
    ``img_mean`` / ``img_std``.

    Returns:
        ``(array, sample[1])`` for mode 'train' or 'val', ``[array]`` for
        mode 'test', and None for any other mode.
    """
    img = Image.open(sample[0])

    if mode == 'train':
        if rotate:
            img = rotate_image(img)
        img = random_crop(img, DATA_DIM)
        if color_jitter:
            img = distort_color(img)
        if random.randint(0, 1) == 1:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
    else:
        img = resize_short(img, target_size=256)
        img = crop_image(img, target_size=DATA_DIM, center=True)

    if img.mode != 'RGB':
        img = img.convert('RGB')

    # HWC -> CHW, then scale to [0, 1] before normalizing.
    pixels = np.array(img).astype('float32').transpose((2, 0, 1)) / 255
    pixels -= img_mean
    pixels /= img_std

    if mode in ('train', 'val'):
        return pixels, sample[1]
    if mode == 'test':
        return [pixels]
class XmapEndSignal():
    """Sentinel type; an instance on a queue marks the end of the stream."""


def xmap_readers(mapper,
                 reader,
                 process_num,
                 buffer_size,
                 order=False,
                 print_queue_state=True):
    """Wrap ``reader`` so that ``mapper`` is applied by a pool of threads.

    Args:
        mapper: callable applied to every sample produced by ``reader``.
        reader: zero-argument callable returning an iterable of samples.
        process_num: number of worker threads running ``mapper``.
        buffer_size: maximum size of the input and the output queue.
        order: if True, mapped samples are yielded in the order ``reader``
            produced them; otherwise in completion order.
        print_queue_state: if True, print both queue sizes whenever more
            than 3 seconds have passed since the last report.

    Returns:
        A zero-argument generator function yielding the mapped samples.
    """
    end = XmapEndSignal()

    # Feed samples from the reader into in_queue, then signal the end.
    def read_worker(reader, in_queue):
        for sample in reader():
            in_queue.put(sample)
        in_queue.put(end)

    # Same as read_worker, but tag every sample with its sequence number.
    # It takes exactly the two arguments the reader thread is started with;
    # the former extra `file_queue` parameter made order=True raise TypeError.
    def order_read_worker(reader, in_queue):
        for in_order, sample in enumerate(reader()):
            in_queue.put((in_order, sample))
        in_queue.put(end)

    # Map samples from in_queue into out_queue until the end signal shows up.
    def handle_worker(in_queue, out_queue, mapper):
        sample = in_queue.get()
        while not isinstance(sample, XmapEndSignal):
            out_queue.put(mapper(sample))
            sample = in_queue.get()
        # Put the signal back so the remaining workers also see it.
        in_queue.put(end)
        out_queue.put(end)

    # Like handle_worker, but only emit a result once every earlier sequence
    # number has been emitted (out_order[0] is the next number allowed out).
    def order_handle_worker(in_queue, out_queue, mapper, out_order):
        ins = in_queue.get()
        while not isinstance(ins, XmapEndSignal):
            order, sample = ins
            result = mapper(sample)
            while order != out_order[0]:
                # Yield to the other threads instead of spinning.
                time.sleep(0)
            out_queue.put(result)
            out_order[0] += 1
            ins = in_queue.get()
        in_queue.put(end)
        out_queue.put(end)

    def xreader():
        in_queue = Queue(buffer_size)
        out_queue = Queue(buffer_size)
        out_order = [0]

        # Start the single reader thread.
        read_target = order_read_worker if order else read_worker
        feeder = Thread(target=read_target, args=(reader, in_queue))
        feeder.daemon = True
        feeder.start()

        # Start the mapper threads.
        if order:
            handle_target = order_handle_worker
            handle_args = (in_queue, out_queue, mapper, out_order)
        else:
            handle_target = handle_worker
            handle_args = (in_queue, out_queue, mapper)
        workers = []
        for _ in range(process_num):
            worker = Thread(target=handle_target, args=handle_args)
            worker.daemon = True
            workers.append(worker)
        for worker in workers:
            worker.start()

        sample = out_queue.get()
        start_t = time.time()
        while not isinstance(sample, XmapEndSignal):
            yield sample
            sample = out_queue.get()
            if time.time() - start_t > 3:
                if print_queue_state:
                    print("queue sizes: ", in_queue.qsize(), out_queue.qsize())
                start_t = time.time()

        # One end signal has been seen; every other worker posts its own, and
        # samples may still be interleaved with them.
        finish = 1
        while finish < process_num:
            sample = out_queue.get()
            if isinstance(sample, XmapEndSignal):
                finish += 1
            else:
                yield sample

    return xreader
def _reader_creator(file_list,
                    mode,
                    shuffle=False,
                    color_jitter=False,
                    rotate=False,
                    xmap=True):
    """Build a multi-threaded reader over the images listed in ``file_list``.

    Args:
        file_list: path of a text file with one sample per line. For modes
            'train' and 'val' a line is ``<relative path> <label>``; for mode
            'test' it is a path relative to DATA_DIR.
        mode: 'train', 'val' or 'test'. Lines are skipped for any other value.
        shuffle: shuffle the lines with the ``random`` module on every pass.
        color_jitter: forwarded to process_image.
        rotate: forwarded to process_image.
        xmap: accepted for call compatibility; it is not consulted, and the
            reader is always wrapped with ``paddle.reader.xmap_readers``.

    Returns:
        The reader returned by ``paddle.reader.xmap_readers``, which applies
        process_image to every sample.
    """

    def reader():
        # Read the whole list up front so the file is closed before yielding.
        with open(file_list) as flist:
            full_lines = [line.strip() for line in flist]
        if shuffle:
            random.shuffle(full_lines)

        if mode == 'train':
            # Each trainer takes one equal-sized contiguous slice; lines left
            # over when the count does not divide evenly are not used.
            trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
            trainer_count = int(os.getenv("PADDLE_TRAINERS", "1"))
            # Floor division: a float here would break slicing on Python 3.
            per_node_lines = len(full_lines) // trainer_count
            begin = trainer_id * per_node_lines
            lines = full_lines[begin:begin + per_node_lines]
            print(
                "read images from %d, length: %d, lines length: %d, total: %d"
                % (begin, per_node_lines, len(lines), len(full_lines)))
        else:
            lines = full_lines

        for line in lines:
            if mode == 'train' or mode == 'val':
                img_path, label = line.split()
                img_path = img_path.replace("JPEG", "jpeg")
                # The sub-directory under DATA_DIR is named after the mode.
                img_path = os.path.join(DATA_DIR, mode, img_path)
                yield (img_path, int(label))
            elif mode == 'test':
                img_path = os.path.join(DATA_DIR, line)
                yield [img_path]

    mapper = functools.partial(
        process_image, mode=mode, color_jitter=color_jitter, rotate=rotate)
    return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE)
def load_raw_image_uint8(sample):
    """Load the image at ``sample[0]`` without any preprocessing.

    Returns ``(array, label)`` where ``array`` holds the image data cast to
    int64 (despite the function's name) and ``label`` is ``int(sample[1])``.
    """
    raw = Image.open(sample[0])
    pixels = np.array(raw).astype('int64')
    label = int(sample[1])
    return pixels, label
def train_raw(file_list=TRAIN_LIST, shuffle=True):
    """Build a training reader that yields raw, un-augmented images.

    Args:
        file_list: text file with one ``<relative path> <label>`` per line;
            paths are resolved under ``DATA_DIR/train``.
        shuffle: shuffle the lines with the ``random`` module on every pass.

    Returns:
        The reader returned by ``paddle.reader.xmap_readers``, which applies
        load_raw_image_uint8 to every ``(path, label)`` sample.
    """

    def reader():
        # Read the whole list up front so the file is closed before yielding.
        with open(file_list) as flist:
            full_lines = [line.strip() for line in flist]
        if shuffle:
            random.shuffle(full_lines)

        # Same defaults as _reader_creator: a single trainer with id 0 when
        # the variables are unset (int(None) would raise TypeError).
        trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
        trainer_count = int(os.getenv("PADDLE_TRAINERS", "1"))
        # Floor division: a float here would break slicing on Python 3.
        per_node_lines = len(full_lines) // trainer_count
        begin = trainer_id * per_node_lines
        lines = full_lines[begin:begin + per_node_lines]
        print("read images from %d, length: %d, lines length: %d, total: %d"
              % (begin, per_node_lines, len(lines), len(full_lines)))

        for line in lines:
            img_path, label = line.split()
            img_path = img_path.replace("JPEG", "jpeg")
            img_path = os.path.join(DATA_DIR, "train", img_path)
            yield (img_path, int(label))

    return paddle.reader.xmap_readers(load_raw_image_uint8, reader, THREAD,
                                      BUF_SIZE)
def train(file_list=TRAIN_LIST, xmap=True):
    """Return the training reader: shuffled, no color jitter, no rotation."""
    options = dict(shuffle=True, color_jitter=False, rotate=False, xmap=xmap)
    return _reader_creator(file_list, 'train', **options)
def val(file_list=TEST_LIST, xmap=True):
    """Return the validation reader; samples keep the list-file order."""
    val_reader = _reader_creator(
        file_list, mode='val', shuffle=False, xmap=xmap)
    return val_reader
def test(file_list=TEST_LIST):
    """Return the test reader; it yields images only, in list-file order."""
    test_reader = _reader_creator(file_list, mode='test', shuffle=False)
    return test_reader
if __name__ == "__main__":
    # Quick throughput check: pull up to 10000 training samples and time it.
    start_t = time.time()
    for count, _ in enumerate(train()(), 1):
        if count >= 10000:
            break
    spent = time.time() - start_t
    print("read 10000 speed: ", 10000 / spent, spent)
...@@ -15,8 +15,6 @@ THREAD = 8 ...@@ -15,8 +15,6 @@ THREAD = 8
BUF_SIZE = 102400 BUF_SIZE = 102400
DATA_DIR = 'data/ILSVRC2012' DATA_DIR = 'data/ILSVRC2012'
TRAIN_LIST = 'data/ILSVRC2012/train_list.txt'
TEST_LIST = 'data/ILSVRC2012/val_list.txt'
img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1)) img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1)) img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
...@@ -131,19 +129,35 @@ def _reader_creator(file_list, ...@@ -131,19 +129,35 @@ def _reader_creator(file_list,
mode, mode,
shuffle=False, shuffle=False,
color_jitter=False, color_jitter=False,
rotate=False): rotate=False,
data_dir=DATA_DIR):
def reader(): def reader():
with open(file_list) as flist: with open(file_list) as flist:
lines = [line.strip() for line in flist] full_lines = [line.strip() for line in flist]
if shuffle: if shuffle:
np.random.shuffle(lines) np.random.shuffle(full_lines)
if mode == 'train' and os.getenv('PADDLE_TRAINING_ROLE'):
# distributed mode if the env var `PADDLE_TRAINING_ROLE` exists
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
trainer_count = int(os.getenv("PADDLE_TRAINERS", "1"))
per_node_lines = len(full_lines) / trainer_count
lines = full_lines[trainer_id * per_node_lines:(trainer_id + 1)
* per_node_lines]
print(
"read images from %d, length: %d, lines length: %d, total: %d"
% (trainer_id * per_node_lines, per_node_lines, len(lines),
len(full_lines)))
else:
lines = full_lines
for line in lines: for line in lines:
if mode == 'train' or mode == 'val': if mode == 'train' or mode == 'val':
img_path, label = line.split() img_path, label = line.split()
img_path = os.path.join(DATA_DIR, img_path) img_path = img_path.replace("JPEG", "jpeg")
img_path = os.path.join(data_dir, img_path)
yield img_path, int(label) yield img_path, int(label)
elif mode == 'test': elif mode == 'test':
img_path = os.path.join(DATA_DIR, line) img_path = os.path.join(data_dir, line)
yield [img_path] yield [img_path]
mapper = functools.partial( mapper = functools.partial(
...@@ -152,14 +166,17 @@ def _reader_creator(file_list, ...@@ -152,14 +166,17 @@ def _reader_creator(file_list,
return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE) return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE)
def train(file_list=TRAIN_LIST): def train(data_dir=DATA_DIR):
file_list = os.path.join(data_dir, 'train_list.txt')
return _reader_creator( return _reader_creator(
file_list, 'train', shuffle=True, color_jitter=False, rotate=False) file_list, 'train', shuffle=True, color_jitter=False, rotate=False, data_dir=data_dir)
def val(file_list=TEST_LIST): def val(data_dir=DATA_DIR):
return _reader_creator(file_list, 'val', shuffle=False) file_list = os.path.join(data_dir, 'val_list.txt')
return _reader_creator(file_list, 'val', shuffle=False, data_dir=data_dir)
def test(file_list=TEST_LIST): def test(data_dir=DATA_DIR):
return _reader_creator(file_list, 'test', shuffle=False) file_list = os.path.join(data_dir, 'val_list.txt')
return _reader_creator(file_list, 'test', shuffle=False, data_dir=data_dir)
...@@ -33,6 +33,7 @@ add_arg('lr', float, 0.1, "set learning rate.") ...@@ -33,6 +33,7 @@ add_arg('lr', float, 0.1, "set learning rate.")
add_arg('lr_strategy', str, "piecewise_decay", "Set the learning rate decay strategy.") add_arg('lr_strategy', str, "piecewise_decay", "Set the learning rate decay strategy.")
add_arg('model', str, "SE_ResNeXt50_32x4d", "Set the network to use.") add_arg('model', str, "SE_ResNeXt50_32x4d", "Set the network to use.")
add_arg('enable_ce', bool, False, "If set True, enable continuous evaluation job.") add_arg('enable_ce', bool, False, "If set True, enable continuous evaluation job.")
add_arg('data_dir', str, "./data/ILSVRC2012", "The ImageNet dataset root dir.")
# yapf: enable # yapf: enable
model_list = [m for m in dir(models) if "__" not in m] model_list = [m for m in dir(models) if "__" not in m]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册