diff --git a/fluid/image_classification/dist_train/README.md b/fluid/image_classification/dist_train/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..02bbea17f423fe2e16fd3115058ba92805a313ab
--- /dev/null
+++ b/fluid/image_classification/dist_train/README.md
@@ -0,0 +1,113 @@
+# Distributed Image Classification Models Training
+
+This folder contains implementations of **Image Classification Models**. They are designed to support
+large-scale distributed training with two distributed modes: parameter server mode and NCCL2 (NVIDIA NCCL2 communication library) collective mode.
+
+## Getting Started
+
+Before getting started, please make sure you have gone through the ImageNet [Data Preparation](../README.md#data-preparation).
+
+1. The entry point is `dist_train.py`; some important flags are as follows:
+
+  - `model`, the model to run, such as `ResNet50`, `ResNet101`, etc.
+  - `batch_size`, the batch size per device.
+  - `update_method`, the parameter update method; one of `local`, `pserver` or `nccl2`.
+  - `device`, the device type to use, CPU or GPU.
+  - `gpus`, the number of GPU devices used by the process.
+
+  You can check out more details of the flags by running `python dist_train.py --help`.
+
+1. Runtime configurations
+
+  We use environment variables to distinguish the training roles within a distributed training job; a short sketch after the list below shows how `dist_train.py` consumes them.
+
+  - `PADDLE_TRAINING_ROLE`, the current training role, should be one of [PSERVER, TRAINER].
+  - `PADDLE_TRAINERS`, the trainer count of a job.
+  - `PADDLE_CURRENT_IP`, the IP of the current instance.
+  - `PADDLE_PSERVER_IPS`, the parameter server IP list, separated by ","; only used when update_method is pserver.
+  - `PADDLE_TRAINER_ID`, the unique trainer ID of a job, ranging in [0, PADDLE_TRAINERS).
+  - `PADDLE_PSERVER_PORT`, the port the parameter server listens on; in nccl2 mode it is also used as the trainer communication port.
+  - `PADDLE_TRAINER_IPS`, the trainer IP list, separated by ","; only used when update_method is nccl2.
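+
+  The following is a minimal sketch of how `dist_train.py` turns these variables into
+  the parameter server endpoint list and trainer counts (it mirrors `dist_transpile` in
+  `dist_train.py` and is illustrative only; the authoritative logic lives in that file):
+
+  ``` python
+  import os
+
+  # assemble an "ip:port" endpoint for every parameter server
+  port = os.getenv("PADDLE_PSERVER_PORT", "6174")
+  pserver_ips = os.getenv("PADDLE_PSERVER_IPS", "")
+  pserver_endpoints = ",".join(
+      ["%s:%s" % (ip, port) for ip in pserver_ips.split(",")])
+
+  # the identity of the current process within the job
+  training_role = os.getenv("PADDLE_TRAINING_ROLE")  # PSERVER or TRAINER
+  trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
+  trainers = int(os.getenv("PADDLE_TRAINERS", "1"))
+  ```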
+
+### Parameter Server Mode
+
+In this example, we launch 4 parameter server instances and 4 trainer instances in the cluster:
+
+1. Launch the parameter server processes:
+
+  ``` bash
+ PADDLE_TRAINING_ROLE=PSERVER \
+ PADDLE_TRAINERS=4 \
+ PADDLE_PSERVER_IPS=192.168.0.100,192.168.0.101,192.168.0.102,192.168.0.103 \
+ PADDLE_CURRENT_IP=192.168.0.100 \
+ PADDLE_PSERVER_PORT=7164 \
+ python dist_train.py \
+ --model=ResNet50 \
+ --batch_size=32 \
+ --update_method=pserver \
+ --device=CPU \
+ --data_dir=../data/ILSVRC2012
+ ```
+
+1. Launch the trainer processes:
+
+  ``` bash
+  PADDLE_TRAINING_ROLE=TRAINER \
+ PADDLE_TRAINERS=4 \
+ PADDLE_PSERVER_IPS=192.168.0.100,192.168.0.101,192.168.0.102,192.168.0.103 \
+ PADDLE_TRAINER_ID=0 \
+ PADDLE_PSERVER_PORT=7164 \
+ python dist_train.py \
+ --model=ResNet50 \
+ --batch_size=32 \
+ --update_method=pserver \
+ --device=GPU \
+ --data_dir=../data/ILSVRC2012
+
+ ```
+
+### NCCL2 Collective Mode
+
+1. Launch the trainer processes:
+
+  ``` bash
+  PADDLE_TRAINING_ROLE=TRAINER \
+  PADDLE_TRAINERS=4 \
+  PADDLE_TRAINER_IPS=192.168.0.100,192.168.0.101,192.168.0.102,192.168.0.103 \
+  PADDLE_TRAINER_ID=0 \
+  PADDLE_CURRENT_IP=192.168.0.100 \
+  PADDLE_PSERVER_PORT=7164 \
+ python dist_train.py \
+ --model=ResNet50 \
+ --batch_size=32 \
+    --update_method=nccl2 \
+ --device=GPU \
+ --data_dir=../data/ILSVRC2012
+ ```
+
+### Visualize the Training Process
+
+It's easy to draw the learning curve according to the training logs; for example,
+the logs of ResNet50 are as follows:
+
+``` text
+Pass 0, batch 0, loss 7.0336914, accuracies: [0.0, 0.00390625]
+Pass 0, batch 1, loss 7.094781, accuracies: [0.0, 0.0]
+Pass 0, batch 2, loss 7.007068, accuracies: [0.0, 0.0078125]
+Pass 0, batch 3, loss 7.1056547, accuracies: [0.00390625, 0.00390625]
+Pass 0, batch 4, loss 7.133543, accuracies: [0.0, 0.0078125]
+Pass 0, batch 5, loss 7.3055463, accuracies: [0.0078125, 0.01171875]
+Pass 0, batch 6, loss 7.341838, accuracies: [0.0078125, 0.01171875]
+Pass 0, batch 7, loss 7.290557, accuracies: [0.0, 0.0]
+Pass 0, batch 8, loss 7.264951, accuracies: [0.0, 0.00390625]
+Pass 0, batch 9, loss 7.43522, accuracies: [0.00390625, 0.00390625]
+```
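+
+The loss values can be pulled out of such a log and plotted with a small script like the
+minimal sketch below (it assumes the log was saved to a file named `train.log` and that
+`matplotlib` is installed; both are assumptions, not something `dist_train.py` produces):
+
+``` python
+import re
+
+import matplotlib
+matplotlib.use('Agg')  # render to a file, no display required
+import matplotlib.pyplot as plt
+
+losses = []
+with open('train.log') as f:
+    for line in f:
+        m = re.match(r'Pass (\d+), batch (\d+), loss ([\d.]+)', line)
+        if m:
+            losses.append(float(m.group(3)))
+
+plt.plot(range(len(losses)), losses)
+plt.xlabel('iteration')
+plt.ylabel('training loss')
+plt.savefig('loss_curve.png')
+```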
+
+The top-1 training accuracy curves of local training, distributed training with NCCL2 and distributed training with the parameter server architecture on the ResNet50 model are shown in the figure below:
+
+
+![Training acc1 curves](../images/resnet50_32gpus-acc1.png)
+
+### Performance
+
+TBD
\ No newline at end of file
diff --git a/fluid/image_classification/dist_train/__init__.py b/fluid/image_classification/dist_train/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/fluid/image_classification/dist_train/args.py b/fluid/image_classification/dist_train/args.py
new file mode 100644
index 0000000000000000000000000000000000000000..fff9fd11a194abf05e49b76fb89125036b2c893d
--- /dev/null
+++ b/fluid/image_classification/dist_train/args.py
@@ -0,0 +1,118 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+
+__all__ = ['parse_args', ]
+
+BENCHMARK_MODELS = [
+ "ResNet50", "ResNet101", "ResNet152"
+]
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Distributed Image Classification Training.')
+ parser.add_argument(
+ '--model',
+ type=str,
+ choices=BENCHMARK_MODELS,
+        default='ResNet50',
+ help='The model to run benchmark with.')
+ parser.add_argument(
+ '--batch_size', type=int, default=32, help='The minibatch size.')
+ # args related to learning rate
+ parser.add_argument(
+ '--learning_rate', type=float, default=0.001, help='The learning rate.')
+ # TODO(wuyi): add "--use_fake_data" option back.
+ parser.add_argument(
+ '--skip_batch_num',
+ type=int,
+ default=5,
+        help='The number of initial minibatches to skip, for a more accurate performance test.'
+ )
+ parser.add_argument(
+ '--iterations', type=int, default=80, help='The number of minibatches.')
+ parser.add_argument(
+ '--pass_num', type=int, default=100, help='The number of passes.')
+ parser.add_argument(
+ '--data_format',
+ type=str,
+ default='NCHW',
+ choices=['NCHW', 'NHWC'],
+        help='The data format; currently only NCHW is supported.')
+ parser.add_argument(
+ '--device',
+ type=str,
+ default='GPU',
+ choices=['CPU', 'GPU'],
+ help='The device type.')
+ parser.add_argument(
+ '--gpus',
+ type=int,
+ default=1,
+ help='If gpus > 1, will use ParallelExecutor to run, else use Executor.')
+ # this option is available only for vgg and resnet.
+ parser.add_argument(
+ '--cpus',
+ type=int,
+ default=1,
+ help='If cpus > 1, will set ParallelExecutor to use multiple threads.')
+ parser.add_argument(
+ '--data_set',
+ type=str,
+ default='flowers',
+ choices=['cifar10', 'flowers', 'imagenet'],
+ help='Optional dataset for benchmark.')
+ parser.add_argument(
+ '--no_test',
+ action='store_true',
+ help='If set, do not test the testset during training.')
+ parser.add_argument(
+ '--memory_optimize',
+ action='store_true',
+ help='If set, optimize runtime memory before start.')
+ parser.add_argument(
+ '--update_method',
+ type=str,
+ default='local',
+ choices=['local', 'pserver', 'nccl2'],
+ help='Choose parameter update method, can be local, pserver, nccl2.')
+ parser.add_argument(
+ '--no_split_var',
+ action='store_true',
+ default=False,
+        help='Whether to split variables into blocks when update_method is pserver')
+ parser.add_argument(
+ '--async_mode',
+ action='store_true',
+ default=False,
+        help='Whether to start pserver in async mode to support ASGD')
+ parser.add_argument(
+ '--no_random',
+ action='store_true',
+ help='If set, keep the random seed and do not shuffle the data.')
+ parser.add_argument(
+ '--reduce_strategy',
+ type=str,
+ choices=['reduce', 'all_reduce'],
+ default='all_reduce',
+        help='Specify the reduce strategy: reduce or all_reduce.')
+ parser.add_argument(
+ '--data_dir',
+ type=str,
+ default="../data/ILSVRC2012",
+ help="The ImageNet dataset root dir."
+ )
+ args = parser.parse_args()
+ return args
diff --git a/fluid/image_classification/dist_train/dist_train.py b/fluid/image_classification/dist_train/dist_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f3953d1320441030c3cdee899ec85ffc8f8f401
--- /dev/null
+++ b/fluid/image_classification/dist_train/dist_train.py
@@ -0,0 +1,363 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import time
+import os
+import traceback
+
+import numpy as np
+
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.core as core
+import sys
+sys.path.append("..")
+import models
+from args import *
+from reader import train, val
+
+def get_model(args, is_train, main_prog, startup_prog):
+ pyreader = None
+ class_dim = 1000
+ if args.data_format == 'NCHW':
+ dshape = [3, 224, 224]
+ else:
+ dshape = [224, 224, 3]
+ if is_train:
+ reader = train(data_dir=args.data_dir)
+ else:
+ reader = val(data_dir=args.data_dir)
+
+ trainer_count = int(os.getenv("PADDLE_TRAINERS", "1"))
+ with fluid.program_guard(main_prog, startup_prog):
+ with fluid.unique_name.guard():
+ pyreader = fluid.layers.py_reader(
+ capacity=args.batch_size * args.gpus,
+ shapes=([-1] + dshape, (-1, 1)),
+ dtypes=('float32', 'int64'),
+ name="train_reader" if is_train else "test_reader",
+ use_double_buffer=True)
+ input, label = fluid.layers.read_file(pyreader)
+ model_def = models.__dict__[args.model]()
+ predict = model_def.net(input, class_dim=class_dim)
+
+ cost = fluid.layers.cross_entropy(input=predict, label=label)
+ avg_cost = fluid.layers.mean(x=cost)
+
+ batch_acc1 = fluid.layers.accuracy(input=predict, label=label, k=1)
+ batch_acc5 = fluid.layers.accuracy(input=predict, label=label, k=5)
+
+            # configure optimizer
+ optimizer = None
+ if is_train:
+
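+                # ILSVRC2012 has 1,281,167 training images and each trainer
+                # reads an equal shard of them, so `step` below is the number
+                # of minibatches per epoch on this trainer; the learning rate
+                # is decayed by 0.1x at epochs 30, 60 and 90 via piecewise_decay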
+ total_images = 1281167 / trainer_count
+
+ step = int(total_images / (args.batch_size * args.gpus) + 1)
+ epochs = [30, 60, 90]
+ bd = [step * e for e in epochs]
+ base_lr = args.learning_rate
+ lr = []
+ lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)]
+ optimizer = fluid.optimizer.Momentum(
+ learning_rate=fluid.layers.piecewise_decay(
+ boundaries=bd, values=lr),
+ momentum=0.9,
+ regularization=fluid.regularizer.L2Decay(1e-4))
+ optimizer.minimize(avg_cost)
+
+ if args.memory_optimize:
+ fluid.memory_optimize(main_prog)
+
+ batched_reader = None
+ pyreader.decorate_paddle_reader(
+ paddle.batch(
+ reader if args.no_random else paddle.reader.shuffle(
+ reader, buf_size=5120),
+ batch_size=args.batch_size))
+
+ return avg_cost, optimizer, [batch_acc1,
+ batch_acc5], batched_reader, pyreader
+
+def append_nccl2_prepare(trainer_id, startup_prog):
+ if trainer_id >= 0:
+ # append gen_nccl_id at the end of startup program
+ trainer_id = int(os.getenv("PADDLE_TRAINER_ID"))
+ port = os.getenv("PADDLE_PSERVER_PORT")
+ worker_ips = os.getenv("PADDLE_TRAINER_IPS")
+ worker_endpoints = []
+ for ip in worker_ips.split(","):
+ worker_endpoints.append(':'.join([ip, port]))
+ num_trainers = len(worker_endpoints)
+ current_endpoint = os.getenv("PADDLE_CURRENT_IP") + ":" + port
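+        # gen_nccl_id broadcasts the NCCL unique id from trainer 0 to its peers,
+        # so the endpoint list passed to the op should contain only the other
+        # trainers; drop the current trainer's own endpoint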
+ worker_endpoints.remove(current_endpoint)
+
+ nccl_id_var = startup_prog.global_block().create_var(
+ name="NCCLID",
+ persistable=True,
+ type=fluid.core.VarDesc.VarType.RAW)
+ startup_prog.global_block().append_op(
+ type="gen_nccl_id",
+ inputs={},
+ outputs={"NCCLID": nccl_id_var},
+ attrs={
+ "endpoint": current_endpoint,
+ "endpoint_list": worker_endpoints,
+ "trainer_id": trainer_id
+ })
+ return nccl_id_var, num_trainers, trainer_id
+ else:
+        raise Exception("must set a non-negative PADDLE_TRAINER_ID env variable for "
+ "nccl-based dist train.")
+
+
+def dist_transpile(trainer_id, args, train_prog, startup_prog):
+ if trainer_id < 0:
+ return None, None
+
+ # the port of all pservers, needed by both trainer and pserver
+ port = os.getenv("PADDLE_PSERVER_PORT", "6174")
+ # comma separated ips of all pservers, needed by trainer and
+ # pserver
+ pserver_ips = os.getenv("PADDLE_PSERVER_IPS", "")
+ eplist = []
+ for ip in pserver_ips.split(","):
+ eplist.append(':'.join([ip, port]))
+ pserver_endpoints = ",".join(eplist)
+ # total number of workers/trainers in the job, needed by
+ # trainer and pserver
+ trainers = int(os.getenv("PADDLE_TRAINERS"))
+ # the IP of the local machine, needed by pserver only
+ current_endpoint = os.getenv("PADDLE_CURRENT_IP", "") + ":" + port
+ # the role, should be either PSERVER or TRAINER
+ training_role = os.getenv("PADDLE_TRAINING_ROLE")
+
+ config = fluid.DistributeTranspilerConfig()
+ config.slice_var_up = not args.no_split_var
+ t = fluid.DistributeTranspiler(config=config)
+ t.transpile(
+ trainer_id,
+ # NOTE: *MUST* use train_prog, for we are using with guard to
+ # generate different program for train and test.
+ program=train_prog,
+ pservers=pserver_endpoints,
+ trainers=trainers,
+ sync_mode=not args.async_mode,
+ startup_program=startup_prog)
+ if training_role == "PSERVER":
+ pserver_program = t.get_pserver_program(current_endpoint)
+ pserver_startup_program = t.get_startup_program(
+ current_endpoint, pserver_program, startup_program=startup_prog)
+ return pserver_program, pserver_startup_program
+ elif training_role == "TRAINER":
+ train_program = t.get_trainer_program()
+ return train_program, startup_prog
+ else:
+ raise ValueError(
+ 'PADDLE_TRAINING_ROLE environment variable must be either TRAINER or PSERVER'
+ )
+
+
+def test_parallel(exe, test_args, args, test_prog, feeder):
+ acc_evaluators = []
+ for i in xrange(len(test_args[2])):
+ acc_evaluators.append(fluid.metrics.Accuracy())
+
+ to_fetch = [v.name for v in test_args[2]]
+ test_args[4].start()
+ while True:
+ try:
+ acc_rets = exe.run(fetch_list=to_fetch)
+ for i, e in enumerate(acc_evaluators):
+ e.update(
+ value=np.array(acc_rets[i]), weight=args.batch_size)
+ except fluid.core.EOFException as eof:
+ test_args[4].reset()
+ break
+
+ return [e.eval() for e in acc_evaluators]
+
+
+# NOTE: only need to benchmark using parallelexe
+def train_parallel(train_args, test_args, args, train_prog, test_prog,
+ startup_prog, nccl_id_var, num_trainers, trainer_id):
+ over_all_start = time.time()
+ place = core.CPUPlace() if args.device == 'CPU' else core.CUDAPlace(0)
+ feeder = None
+
+ if nccl_id_var and trainer_id == 0:
+ #FIXME(wuyi): wait other trainer to start listening
+ time.sleep(30)
+
+ startup_exe = fluid.Executor(place)
+ startup_exe.run(startup_prog)
+ strategy = fluid.ExecutionStrategy()
+ strategy.num_threads = args.cpus
+ strategy.allow_op_delay = False
+ build_strategy = fluid.BuildStrategy()
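+    # Reduce: each device aggregates and updates only a slice of the
+    # parameters and broadcasts the result to the others; AllReduce: every
+    # device all-reduces all gradients and keeps a full parameter copy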
+ if args.reduce_strategy == "reduce":
+ build_strategy.reduce_strategy = fluid.BuildStrategy(
+ ).ReduceStrategy.Reduce
+ else:
+ build_strategy.reduce_strategy = fluid.BuildStrategy(
+ ).ReduceStrategy.AllReduce
+
+ avg_loss = train_args[0]
+
+ if args.update_method == "pserver":
+ # parameter server mode distributed training, merge
+ # gradients on local server, do not initialize
+ # ParallelExecutor with multi server all-reduce mode.
+ num_trainers = 1
+ trainer_id = 0
+
+ exe = fluid.ParallelExecutor(
+ True,
+ avg_loss.name,
+ main_program=train_prog,
+ exec_strategy=strategy,
+ build_strategy=build_strategy,
+ num_trainers=num_trainers,
+ trainer_id=trainer_id)
+
+ if not args.no_test:
+ if args.update_method == "pserver":
+ test_scope = None
+ else:
+ # NOTE: use an empty scope to avoid test exe using NCCLID
+ test_scope = fluid.Scope()
+ test_exe = fluid.ParallelExecutor(
+ True, main_program=test_prog, share_vars_from=exe)
+
+ pyreader = train_args[4]
+ for pass_id in range(args.pass_num):
+ num_samples = 0
+ iters = 0
+ start_time = time.time()
+ batch_id = 0
+ pyreader.start()
+ while True:
+ if iters == args.iterations:
+ break
+
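+            # after the warm-up batches, restart the clock so the reported
+            # throughput does not include startup and warm-up cost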
+ if iters == args.skip_batch_num:
+ start_time = time.time()
+ num_samples = 0
+ fetch_list = [avg_loss.name]
+ acc_name_list = [v.name for v in train_args[2]]
+ fetch_list.extend(acc_name_list)
+
+ try:
+ fetch_ret = exe.run(fetch_list)
+ except fluid.core.EOFException as eof:
+ break
+ except fluid.core.EnforceNotMet as ex:
+ traceback.print_exc()
+ break
+ num_samples += args.batch_size * args.gpus
+
+ iters += 1
+ if batch_id % 1 == 0:
+ fetched_data = [np.mean(np.array(d)) for d in fetch_ret]
+                print("Pass %d, batch %d, loss %s, accuracies: %s" %
+ (pass_id, batch_id, fetched_data[0], fetched_data[1:]))
+ batch_id += 1
+
+ print_train_time(start_time, time.time(), num_samples)
+ pyreader.reset() # reset reader handle
+
+ if not args.no_test and test_args[2]:
+ test_feeder = None
+ test_ret = test_parallel(test_exe, test_args, args, test_prog,
+ test_feeder)
+ print("Pass: %d, Test Accuracy: %s\n" %
+ (pass_id, [np.mean(np.array(v)) for v in test_ret]))
+
+ startup_exe.close()
+ print("total train time: ", time.time() - over_all_start)
+
+
+def print_arguments(args):
+ print('----------- Configuration Arguments -----------')
+ for arg, value in sorted(vars(args).iteritems()):
+ print('%s: %s' % (arg, value))
+ print('------------------------------------------------')
+
+
+def print_train_time(start_time, end_time, num_samples):
+ train_elapsed = end_time - start_time
+ examples_per_sec = num_samples / train_elapsed
+    print('\nTotal examples: %d, total time: %.5f, %.5f examples/sec\n' %
+ (num_samples, train_elapsed, examples_per_sec))
+
+
+def print_paddle_envs():
+ print('----------- Configuration envs -----------')
+ for k in os.environ:
+ if "PADDLE_" in k:
+            print("ENV %s:%s" % (k, os.environ[k]))
+ print('------------------------------------------------')
+
+
+def main():
+ args = parse_args()
+ print_arguments(args)
+ print_paddle_envs()
+ if args.no_random:
+ fluid.default_startup_program().random_seed = 1
+
+ # the unique trainer id, starting from 0, needed by trainer
+ # only
+ nccl_id_var, num_trainers, trainer_id = (
+ None, 1, int(os.getenv("PADDLE_TRAINER_ID", "0")))
+
+ train_prog = fluid.Program()
+ test_prog = fluid.Program()
+ startup_prog = fluid.Program()
+
+ train_args = list(get_model(args, True, train_prog, startup_prog))
+ test_args = list(get_model(args, False, test_prog, startup_prog))
+
+ all_args = [train_args, test_args, args]
+
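+    # in pserver mode the transpiler rewrites train_prog into either a
+    # parameter server program or a trainer program, depending on the role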
+ if args.update_method == "pserver":
+ train_prog, startup_prog = dist_transpile(trainer_id, args, train_prog,
+ startup_prog)
+ if not train_prog:
+ raise Exception(
+ "Must configure correct environments to run dist train.")
+ all_args.extend([train_prog, test_prog, startup_prog])
+ if args.gpus > 1 and os.getenv("PADDLE_TRAINING_ROLE") == "TRAINER":
+ all_args.extend([nccl_id_var, num_trainers, trainer_id])
+ train_parallel(*all_args)
+ elif os.getenv("PADDLE_TRAINING_ROLE") == "PSERVER":
+ # start pserver with Executor
+ server_exe = fluid.Executor(fluid.CPUPlace())
+ server_exe.run(startup_prog)
+ server_exe.run(train_prog)
+ exit(0)
+
+ # for other update methods, use default programs
+ all_args.extend([train_prog, test_prog, startup_prog])
+
+ if args.update_method == "nccl2":
+ nccl_id_var, num_trainers, trainer_id = append_nccl2_prepare(
+ trainer_id, startup_prog)
+
+ all_args.extend([nccl_id_var, num_trainers, trainer_id])
+ train_parallel(*all_args)
+
+if __name__ == "__main__":
+ main()
diff --git a/fluid/image_classification/images/resnet50_32gpus-acc1.png b/fluid/image_classification/images/resnet50_32gpus-acc1.png
new file mode 100644
index 0000000000000000000000000000000000000000..6d4c478743d0e5af0a9d727c76b433849c6a81dc
Binary files /dev/null and b/fluid/image_classification/images/resnet50_32gpus-acc1.png differ
diff --git a/fluid/image_classification/reader.py b/fluid/image_classification/reader.py
index 3be7667b4a3a5d1c14b54ae88bbd085b3c32dadd..50be1cdef5ff3ad612d4d447a87174a767867a02 100644
--- a/fluid/image_classification/reader.py
+++ b/fluid/image_classification/reader.py
@@ -15,8 +15,6 @@ THREAD = 8
BUF_SIZE = 102400
DATA_DIR = 'data/ILSVRC2012'
-TRAIN_LIST = 'data/ILSVRC2012/train_list.txt'
-TEST_LIST = 'data/ILSVRC2012/val_list.txt'
img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
@@ -131,19 +129,35 @@ def _reader_creator(file_list,
                    mode,
                    shuffle=False,
                    color_jitter=False,
-                    rotate=False):
+                    rotate=False,
+                    data_dir=DATA_DIR):
def reader():
with open(file_list) as flist:
- lines = [line.strip() for line in flist]
+ full_lines = [line.strip() for line in flist]
if shuffle:
- np.random.shuffle(lines)
+ np.random.shuffle(full_lines)
+ if mode == 'train' and os.getenv('PADDLE_TRAINING_ROLE'):
+ # distributed mode if the env var `PADDLE_TRAINING_ROLE` exits
+ trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
+ trainer_count = int(os.getenv("PADDLE_TRAINERS", "1"))
+ per_node_lines = len(full_lines) / trainer_count
+ lines = full_lines[trainer_id * per_node_lines:(trainer_id + 1)
+ * per_node_lines]
+ print(
+ "read images from %d, length: %d, lines length: %d, total: %d"
+ % (trainer_id * per_node_lines, per_node_lines, len(lines),
+ len(full_lines)))
+ else:
+ lines = full_lines
+
for line in lines:
if mode == 'train' or mode == 'val':
img_path, label = line.split()
- img_path = os.path.join(DATA_DIR, img_path)
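+                    # use the lower-case ".jpeg" extension when building the image path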
+                    img_path = img_path.replace("JPEG", "jpeg")
+                    img_path = os.path.join(data_dir, img_path)
+                    yield img_path, int(label)
+                elif mode == 'test':
-                    img_path = os.path.join(DATA_DIR, line)
+                    img_path = os.path.join(data_dir, line)
+                    yield [img_path]
mapper = functools.partial(
@@ -152,14 +166,17 @@ def _reader_creator(file_list,
    return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE)
-def train(file_list=TRAIN_LIST):
+def train(data_dir=DATA_DIR):
+    file_list = os.path.join(data_dir, 'train_list.txt')
+    return _reader_creator(
-        file_list, 'train', shuffle=True, color_jitter=False, rotate=False)
+        file_list, 'train', shuffle=True, color_jitter=False, rotate=False, data_dir=data_dir)
-def val(file_list=TEST_LIST):
-    return _reader_creator(file_list, 'val', shuffle=False)
+def val(data_dir=DATA_DIR):
+    file_list = os.path.join(data_dir, 'val_list.txt')
+    return _reader_creator(file_list, 'val', shuffle=False, data_dir=data_dir)
-def test(file_list=TEST_LIST):
-    return _reader_creator(file_list, 'test', shuffle=False)
+def test(data_dir=DATA_DIR):
+    file_list = os.path.join(data_dir, 'val_list.txt')
+    return _reader_creator(file_list, 'test', shuffle=False, data_dir=data_dir)
diff --git a/fluid/image_classification/train.py b/fluid/image_classification/train.py
index 75a6cfb035de96d4287f31ffa849d69b4b40b1a8..bfc5f8b1412a11606d54b020f29bef969bae2a62 100644
--- a/fluid/image_classification/train.py
+++ b/fluid/image_classification/train.py
@@ -33,6 +33,7 @@ add_arg('lr', float, 0.1, "set learning rate.")
add_arg('lr_strategy', str, "piecewise_decay", "Set the learning rate decay strategy.")
add_arg('model', str, "SE_ResNeXt50_32x4d", "Set the network to use.")
add_arg('enable_ce', bool, False, "If set True, enable continuous evaluation job.")
+add_arg('data_dir', str, "./data/ILSVRC2012", "The ImageNet dataset root dir.")
# yapf: enable
model_list = [m for m in dir(models) if "__" not in m]