diff --git a/fluid/image_classification/dist_train/README.md b/fluid/image_classification/dist_train/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..02bbea17f423fe2e16fd3115058ba92805a313ab
--- /dev/null
+++ b/fluid/image_classification/dist_train/README.md
@@ -0,0 +1,113 @@
+# Distributed Image Classification Models Training
+
+This folder contains implementations of **Image Classification Models** that are designed to support
+large-scale distributed training in two distributed modes: parameter server mode and NCCL2 (NVIDIA NCCL2 communication library) collective mode.
+
+## Getting Started
+
+Before getting started, please make sure you have gone through the ImageNet [Data Preparation](../README.md#data-preparation).
+
+1. The entry point is `dist_train.py`; its most important flags are:
+
+   - `model`, the model to train, such as `ResNet50` or `ResNet101`.
+   - `batch_size`, the batch size per device.
+   - `update_method`, the parameter update method; one of local, pserver or nccl2.
+   - `device`, train on CPU or GPU.
+   - `gpus`, the number of GPU devices used by the process.
+
+   You can check out more details of the flags by running `python dist_train.py --help`.
+
+1. Runtime configurations
+
+   We use environment variables to distinguish the training roles of a distributed training job.
+
+   - `PADDLE_TRAINING_ROLE`, the current training role, one of [PSERVER, TRAINER].
+   - `PADDLE_TRAINERS`, the trainer count of a job.
+   - `PADDLE_CURRENT_IP`, the IP of the current instance.
+   - `PADDLE_PSERVER_IPS`, the parameter server IP list, separated by ","; only used when update_method is pserver.
+   - `PADDLE_TRAINER_ID`, the unique trainer ID of a job, in the range [0, PADDLE_TRAINERS).
+   - `PADDLE_PSERVER_PORT`, the port the parameter servers listen on.
+   - `PADDLE_TRAINER_IPS`, the trainer IP list, separated by ","; only used when update_method is nccl2.
+
+### Parameter Server Mode
+
+In this example, we launch 4 parameter server instances and 4 trainer instances in the cluster; each parameter server must be started with its own `PADDLE_CURRENT_IP`, and each trainer with its own `PADDLE_TRAINER_ID`:
+
+1. launch the parameter server process
+
+   ``` bash
+   PADDLE_TRAINING_ROLE=PSERVER \
+   PADDLE_TRAINERS=4 \
+   PADDLE_PSERVER_IPS=192.168.0.100,192.168.0.101,192.168.0.102,192.168.0.103 \
+   PADDLE_CURRENT_IP=192.168.0.100 \
+   PADDLE_PSERVER_PORT=7164 \
+   python dist_train.py \
+     --model=ResNet50 \
+     --batch_size=32 \
+     --update_method=pserver \
+     --device=CPU \
+     --data_dir=../data/ILSVRC2012
+   ```
+
+1. launch the trainer process
+
+   ``` bash
+   PADDLE_TRAINING_ROLE=TRAINER \
+   PADDLE_TRAINERS=4 \
+   PADDLE_PSERVER_IPS=192.168.0.100,192.168.0.101,192.168.0.102,192.168.0.103 \
+   PADDLE_TRAINER_ID=0 \
+   PADDLE_PSERVER_PORT=7164 \
+   python dist_train.py \
+     --model=ResNet50 \
+     --batch_size=32 \
+     --update_method=pserver \
+     --device=GPU \
+     --data_dir=../data/ILSVRC2012
+   ```
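+
+For a quick single-machine smoke test of the parameter server mode, the two processes above can also be spawned from one small helper script, for example the sketch below (illustrative only and not part of this repo; the file name `launch_local_test.py` and the one-pserver/one-trainer layout on `127.0.0.1` are assumptions):
+
+``` python
+# launch_local_test.py (hypothetical helper): spawn one parameter server and
+# one trainer on the local machine to smoke-test --update_method=pserver.
+import os
+import subprocess
+
+COMMON_ENV = {
+    "PADDLE_TRAINERS": "1",             # a single trainer in this toy job
+    "PADDLE_PSERVER_IPS": "127.0.0.1",  # a single local parameter server
+    "PADDLE_CURRENT_IP": "127.0.0.1",
+    "PADDLE_PSERVER_PORT": "7164",
+}
+CMD = ["python", "dist_train.py", "--model=ResNet50", "--batch_size=32",
+       "--update_method=pserver", "--device=CPU",
+       "--data_dir=../data/ILSVRC2012"]
+
+def launch(role, extra_env=None):
+    # Copy the current environment and add the PaddlePaddle role variables.
+    env = dict(os.environ, PADDLE_TRAINING_ROLE=role, **COMMON_ENV)
+    env.update(extra_env or {})
+    return subprocess.Popen(CMD, env=env)
+
+pserver = launch("PSERVER")
+trainer = launch("TRAINER", {"PADDLE_TRAINER_ID": "0"})
+trainer.wait()       # block until the trainer finishes ...
+pserver.terminate()  # ... then shut the parameter server down
+```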
+
+### NCCL2 Collective Mode
+
+1. launch the trainer process
+
+   ``` bash
+   PADDLE_TRAINING_ROLE=TRAINER \
+   PADDLE_TRAINERS=4 \
+   PADDLE_TRAINER_IPS=192.168.0.100,192.168.0.101,192.168.0.102,192.168.0.103 \
+   PADDLE_TRAINER_ID=0 \
+   PADDLE_CURRENT_IP=192.168.0.100 \
+   PADDLE_PSERVER_PORT=7164 \
+   python dist_train.py \
+     --model=ResNet50 \
+     --batch_size=32 \
+     --update_method=nccl2 \
+     --device=GPU \
+     --data_dir=../data/ILSVRC2012
+   ```
+
+   Note that in NCCL2 mode there is no parameter server: every trainer must be launched with its own `PADDLE_TRAINER_ID` and `PADDLE_CURRENT_IP`, and `PADDLE_PSERVER_PORT` is reused as the port on which trainers exchange the NCCL ID.
+
+### Visualize the Training Process
+
+It's easy to draw the learning curve from the training logs; for example,
+the logs of ResNet50 are as follows:
+
+``` text
+Pass 0, batch 0, loss 7.0336914, accucacys: [0.0, 0.00390625]
+Pass 0, batch 1, loss 7.094781, accucacys: [0.0, 0.0]
+Pass 0, batch 2, loss 7.007068, accucacys: [0.0, 0.0078125]
+Pass 0, batch 3, loss 7.1056547, accucacys: [0.00390625, 0.00390625]
+Pass 0, batch 4, loss 7.133543, accucacys: [0.0, 0.0078125]
+Pass 0, batch 5, loss 7.3055463, accucacys: [0.0078125, 0.01171875]
+Pass 0, batch 6, loss 7.341838, accucacys: [0.0078125, 0.01171875]
+Pass 0, batch 7, loss 7.290557, accucacys: [0.0, 0.0]
+Pass 0, batch 8, loss 7.264951, accucacys: [0.0, 0.00390625]
+Pass 0, batch 9, loss 7.43522, accucacys: [0.00390625, 0.00390625]
+```
+
+(The two values after `accucacys` are the top-1 and top-5 accuracies of the current batch.)
+
+The top-1 training accuracy curves of local training, distributed training with NCCL2, and distributed training with the parameter server architecture on the ResNet50 model are shown in the figure below:
+
+![Training acc1 curves](../images/resnet50_32gpus-acc1.png)
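+
+A curve like this can be extracted from the raw logs with a short script along the following lines (an illustrative sketch, not part of this repo; it assumes the log format shown above and that `matplotlib` is installed):
+
+``` python
+# plot_acc1.py (hypothetical helper): parse "Pass ..., batch ..., loss ...,
+# accucacys: [top1, top5]" lines from a training log and plot the top-1 curve.
+# Usage: python plot_acc1.py train.log
+import re
+import sys
+
+import matplotlib.pyplot as plt
+
+LINE_RE = re.compile(
+    r"Pass (\d+), batch (\d+), loss ([\d.]+), accucacys: \[([\d.]+), ([\d.]+)\]")
+
+acc1 = []
+with open(sys.argv[1]) as log_file:
+    for line in log_file:
+        match = LINE_RE.search(line)
+        if match:
+            acc1.append(float(match.group(4)))  # group 4 is the top-1 accuracy
+
+plt.plot(range(len(acc1)), acc1)
+plt.xlabel("batch")
+plt.ylabel("top-1 accuracy")
+plt.savefig("resnet50_acc1.png")
+```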
+
+### Performance
+
+TBD
\ No newline at end of file
diff --git a/fluid/image_classification/dist_train/__init__.py b/fluid/image_classification/dist_train/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/fluid/image_classification/dist_train/args.py b/fluid/image_classification/dist_train/args.py
new file mode 100644
index 0000000000000000000000000000000000000000..fff9fd11a194abf05e49b76fb89125036b2c893d
--- /dev/null
+++ b/fluid/image_classification/dist_train/args.py
@@ -0,0 +1,118 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+
+__all__ = ['parse_args', ]
+
+BENCHMARK_MODELS = [
+    "ResNet50", "ResNet101", "ResNet152"
+]
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Distributed Image Classification Training.')
+    parser.add_argument(
+        '--model',
+        type=str,
+        choices=BENCHMARK_MODELS,
+        default='ResNet50',
+        help='The model to run benchmark with.')
+    parser.add_argument(
+        '--batch_size', type=int, default=32, help='The minibatch size.')
+    # args related to learning rate
+    parser.add_argument(
+        '--learning_rate', type=float, default=0.001, help='The learning rate.')
+    # TODO(wuyi): add "--use_fake_data" option back.
+    parser.add_argument(
+        '--skip_batch_num',
+        type=int,
+        default=5,
+        help='The number of minibatches to skip at the beginning, for more accurate performance measurement.'
+    )
+    parser.add_argument(
+        '--iterations', type=int, default=80, help='The number of minibatches.')
+    parser.add_argument(
+        '--pass_num', type=int, default=100, help='The number of passes.')
+    parser.add_argument(
+        '--data_format',
+        type=str,
+        default='NCHW',
+        choices=['NCHW', 'NHWC'],
+        help='The data format; currently only NCHW is supported.')
+    parser.add_argument(
+        '--device',
+        type=str,
+        default='GPU',
+        choices=['CPU', 'GPU'],
+        help='The device type.')
+    parser.add_argument(
+        '--gpus',
+        type=int,
+        default=1,
+        help='If gpus > 1, will use ParallelExecutor to run, else use Executor.')
+    # this option is available only for vgg and resnet.
+ parser.add_argument( + '--cpus', + type=int, + default=1, + help='If cpus > 1, will set ParallelExecutor to use multiple threads.') + parser.add_argument( + '--data_set', + type=str, + default='flowers', + choices=['cifar10', 'flowers', 'imagenet'], + help='Optional dataset for benchmark.') + parser.add_argument( + '--no_test', + action='store_true', + help='If set, do not test the testset during training.') + parser.add_argument( + '--memory_optimize', + action='store_true', + help='If set, optimize runtime memory before start.') + parser.add_argument( + '--update_method', + type=str, + default='local', + choices=['local', 'pserver', 'nccl2'], + help='Choose parameter update method, can be local, pserver, nccl2.') + parser.add_argument( + '--no_split_var', + action='store_true', + default=False, + help='Whether split variables into blocks when update_method is pserver') + parser.add_argument( + '--async_mode', + action='store_true', + default=False, + help='Whether start pserver in async mode to support ASGD') + parser.add_argument( + '--no_random', + action='store_true', + help='If set, keep the random seed and do not shuffle the data.') + parser.add_argument( + '--reduce_strategy', + type=str, + choices=['reduce', 'all_reduce'], + default='all_reduce', + help='Specify the reduce strategy, can be reduce, all_reduce') + parser.add_argument( + '--data_dir', + type=str, + default="../data/ILSVRC2012", + help="The ImageNet dataset root dir." + ) + args = parser.parse_args() + return args diff --git a/fluid/image_classification/dist_train/dist_train.py b/fluid/image_classification/dist_train/dist_train.py new file mode 100644 index 0000000000000000000000000000000000000000..8f3953d1320441030c3cdee899ec85ffc8f8f401 --- /dev/null +++ b/fluid/image_classification/dist_train/dist_train.py @@ -0,0 +1,363 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import time +import os +import traceback + +import numpy as np + +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +import sys +sys.path.append("..") +import models +from args import * +from reader import train, val + +def get_model(args, is_train, main_prog, startup_prog): + pyreader = None + class_dim = 1000 + if args.data_format == 'NCHW': + dshape = [3, 224, 224] + else: + dshape = [224, 224, 3] + if is_train: + reader = train(data_dir=args.data_dir) + else: + reader = val(data_dir=args.data_dir) + + trainer_count = int(os.getenv("PADDLE_TRAINERS", "1")) + with fluid.program_guard(main_prog, startup_prog): + with fluid.unique_name.guard(): + pyreader = fluid.layers.py_reader( + capacity=args.batch_size * args.gpus, + shapes=([-1] + dshape, (-1, 1)), + dtypes=('float32', 'int64'), + name="train_reader" if is_train else "test_reader", + use_double_buffer=True) + input, label = fluid.layers.read_file(pyreader) + model_def = models.__dict__[args.model]() + predict = model_def.net(input, class_dim=class_dim) + + cost = fluid.layers.cross_entropy(input=predict, label=label) + avg_cost = fluid.layers.mean(x=cost) + + batch_acc1 = fluid.layers.accuracy(input=predict, label=label, k=1) + batch_acc5 = fluid.layers.accuracy(input=predict, label=label, k=5) + + # configure optimize + optimizer = None + if is_train: + + total_images = 1281167 / trainer_count + + step = int(total_images / (args.batch_size * args.gpus) + 1) + epochs = [30, 60, 90] + bd = [step * e for e in epochs] + base_lr = args.learning_rate + lr = [] + lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)] + optimizer = fluid.optimizer.Momentum( + learning_rate=fluid.layers.piecewise_decay( + boundaries=bd, values=lr), + momentum=0.9, + regularization=fluid.regularizer.L2Decay(1e-4)) + optimizer.minimize(avg_cost) + + if args.memory_optimize: + fluid.memory_optimize(main_prog) + + batched_reader = None + pyreader.decorate_paddle_reader( + paddle.batch( + reader if args.no_random else paddle.reader.shuffle( + reader, buf_size=5120), + batch_size=args.batch_size)) + + return avg_cost, optimizer, [batch_acc1, + batch_acc5], batched_reader, pyreader + +def append_nccl2_prepare(trainer_id, startup_prog): + if trainer_id >= 0: + # append gen_nccl_id at the end of startup program + trainer_id = int(os.getenv("PADDLE_TRAINER_ID")) + port = os.getenv("PADDLE_PSERVER_PORT") + worker_ips = os.getenv("PADDLE_TRAINER_IPS") + worker_endpoints = [] + for ip in worker_ips.split(","): + worker_endpoints.append(':'.join([ip, port])) + num_trainers = len(worker_endpoints) + current_endpoint = os.getenv("PADDLE_CURRENT_IP") + ":" + port + worker_endpoints.remove(current_endpoint) + + nccl_id_var = startup_prog.global_block().create_var( + name="NCCLID", + persistable=True, + type=fluid.core.VarDesc.VarType.RAW) + startup_prog.global_block().append_op( + type="gen_nccl_id", + inputs={}, + outputs={"NCCLID": nccl_id_var}, + attrs={ + "endpoint": current_endpoint, + "endpoint_list": worker_endpoints, + "trainer_id": trainer_id + }) + return nccl_id_var, num_trainers, trainer_id + else: + raise Exception("must set positive PADDLE_TRAINER_ID env variables for " + "nccl-based dist train.") + + +def dist_transpile(trainer_id, args, train_prog, startup_prog): + if trainer_id < 0: + return None, None + + # the port of all pservers, needed by both trainer and pserver + port = os.getenv("PADDLE_PSERVER_PORT", "6174") + # comma separated ips of all pservers, needed by trainer and + # pserver + pserver_ips = 
os.getenv("PADDLE_PSERVER_IPS", "") + eplist = [] + for ip in pserver_ips.split(","): + eplist.append(':'.join([ip, port])) + pserver_endpoints = ",".join(eplist) + # total number of workers/trainers in the job, needed by + # trainer and pserver + trainers = int(os.getenv("PADDLE_TRAINERS")) + # the IP of the local machine, needed by pserver only + current_endpoint = os.getenv("PADDLE_CURRENT_IP", "") + ":" + port + # the role, should be either PSERVER or TRAINER + training_role = os.getenv("PADDLE_TRAINING_ROLE") + + config = fluid.DistributeTranspilerConfig() + config.slice_var_up = not args.no_split_var + t = fluid.DistributeTranspiler(config=config) + t.transpile( + trainer_id, + # NOTE: *MUST* use train_prog, for we are using with guard to + # generate different program for train and test. + program=train_prog, + pservers=pserver_endpoints, + trainers=trainers, + sync_mode=not args.async_mode, + startup_program=startup_prog) + if training_role == "PSERVER": + pserver_program = t.get_pserver_program(current_endpoint) + pserver_startup_program = t.get_startup_program( + current_endpoint, pserver_program, startup_program=startup_prog) + return pserver_program, pserver_startup_program + elif training_role == "TRAINER": + train_program = t.get_trainer_program() + return train_program, startup_prog + else: + raise ValueError( + 'PADDLE_TRAINING_ROLE environment variable must be either TRAINER or PSERVER' + ) + + +def test_parallel(exe, test_args, args, test_prog, feeder): + acc_evaluators = [] + for i in xrange(len(test_args[2])): + acc_evaluators.append(fluid.metrics.Accuracy()) + + to_fetch = [v.name for v in test_args[2]] + test_args[4].start() + while True: + try: + acc_rets = exe.run(fetch_list=to_fetch) + for i, e in enumerate(acc_evaluators): + e.update( + value=np.array(acc_rets[i]), weight=args.batch_size) + except fluid.core.EOFException as eof: + test_args[4].reset() + break + + return [e.eval() for e in acc_evaluators] + + +# NOTE: only need to benchmark using parallelexe +def train_parallel(train_args, test_args, args, train_prog, test_prog, + startup_prog, nccl_id_var, num_trainers, trainer_id): + over_all_start = time.time() + place = core.CPUPlace() if args.device == 'CPU' else core.CUDAPlace(0) + feeder = None + + if nccl_id_var and trainer_id == 0: + #FIXME(wuyi): wait other trainer to start listening + time.sleep(30) + + startup_exe = fluid.Executor(place) + startup_exe.run(startup_prog) + strategy = fluid.ExecutionStrategy() + strategy.num_threads = args.cpus + strategy.allow_op_delay = False + build_strategy = fluid.BuildStrategy() + if args.reduce_strategy == "reduce": + build_strategy.reduce_strategy = fluid.BuildStrategy( + ).ReduceStrategy.Reduce + else: + build_strategy.reduce_strategy = fluid.BuildStrategy( + ).ReduceStrategy.AllReduce + + avg_loss = train_args[0] + + if args.update_method == "pserver": + # parameter server mode distributed training, merge + # gradients on local server, do not initialize + # ParallelExecutor with multi server all-reduce mode. 
+ num_trainers = 1 + trainer_id = 0 + + exe = fluid.ParallelExecutor( + True, + avg_loss.name, + main_program=train_prog, + exec_strategy=strategy, + build_strategy=build_strategy, + num_trainers=num_trainers, + trainer_id=trainer_id) + + if not args.no_test: + if args.update_method == "pserver": + test_scope = None + else: + # NOTE: use an empty scope to avoid test exe using NCCLID + test_scope = fluid.Scope() + test_exe = fluid.ParallelExecutor( + True, main_program=test_prog, share_vars_from=exe) + + pyreader = train_args[4] + for pass_id in range(args.pass_num): + num_samples = 0 + iters = 0 + start_time = time.time() + batch_id = 0 + pyreader.start() + while True: + if iters == args.iterations: + break + + if iters == args.skip_batch_num: + start_time = time.time() + num_samples = 0 + fetch_list = [avg_loss.name] + acc_name_list = [v.name for v in train_args[2]] + fetch_list.extend(acc_name_list) + + try: + fetch_ret = exe.run(fetch_list) + except fluid.core.EOFException as eof: + break + except fluid.core.EnforceNotMet as ex: + traceback.print_exc() + break + num_samples += args.batch_size * args.gpus + + iters += 1 + if batch_id % 1 == 0: + fetched_data = [np.mean(np.array(d)) for d in fetch_ret] + print("Pass %d, batch %d, loss %s, accucacys: %s" % + (pass_id, batch_id, fetched_data[0], fetched_data[1:])) + batch_id += 1 + + print_train_time(start_time, time.time(), num_samples) + pyreader.reset() # reset reader handle + + if not args.no_test and test_args[2]: + test_feeder = None + test_ret = test_parallel(test_exe, test_args, args, test_prog, + test_feeder) + print("Pass: %d, Test Accuracy: %s\n" % + (pass_id, [np.mean(np.array(v)) for v in test_ret])) + + startup_exe.close() + print("total train time: ", time.time() - over_all_start) + + +def print_arguments(args): + print('----------- Configuration Arguments -----------') + for arg, value in sorted(vars(args).iteritems()): + print('%s: %s' % (arg, value)) + print('------------------------------------------------') + + +def print_train_time(start_time, end_time, num_samples): + train_elapsed = end_time - start_time + examples_per_sec = num_samples / train_elapsed + print('\nTotal examples: %d, total time: %.5f, %.5f examples/sed\n' % + (num_samples, train_elapsed, examples_per_sec)) + + +def print_paddle_envs(): + print('----------- Configuration envs -----------') + for k in os.environ: + if "PADDLE_" in k: + print "ENV %s:%s" % (k, os.environ[k]) + print('------------------------------------------------') + + +def main(): + args = parse_args() + print_arguments(args) + print_paddle_envs() + if args.no_random: + fluid.default_startup_program().random_seed = 1 + + # the unique trainer id, starting from 0, needed by trainer + # only + nccl_id_var, num_trainers, trainer_id = ( + None, 1, int(os.getenv("PADDLE_TRAINER_ID", "0"))) + + train_prog = fluid.Program() + test_prog = fluid.Program() + startup_prog = fluid.Program() + + train_args = list(get_model(args, True, train_prog, startup_prog)) + test_args = list(get_model(args, False, test_prog, startup_prog)) + + all_args = [train_args, test_args, args] + + if args.update_method == "pserver": + train_prog, startup_prog = dist_transpile(trainer_id, args, train_prog, + startup_prog) + if not train_prog: + raise Exception( + "Must configure correct environments to run dist train.") + all_args.extend([train_prog, test_prog, startup_prog]) + if args.gpus > 1 and os.getenv("PADDLE_TRAINING_ROLE") == "TRAINER": + all_args.extend([nccl_id_var, num_trainers, trainer_id]) + 
train_parallel(*all_args) + elif os.getenv("PADDLE_TRAINING_ROLE") == "PSERVER": + # start pserver with Executor + server_exe = fluid.Executor(fluid.CPUPlace()) + server_exe.run(startup_prog) + server_exe.run(train_prog) + exit(0) + + # for other update methods, use default programs + all_args.extend([train_prog, test_prog, startup_prog]) + + if args.update_method == "nccl2": + nccl_id_var, num_trainers, trainer_id = append_nccl2_prepare( + trainer_id, startup_prog) + + all_args.extend([nccl_id_var, num_trainers, trainer_id]) + train_parallel(*all_args) + +if __name__ == "__main__": + main() diff --git a/fluid/image_classification/images/resnet50_32gpus-acc1.png b/fluid/image_classification/images/resnet50_32gpus-acc1.png new file mode 100644 index 0000000000000000000000000000000000000000..6d4c478743d0e5af0a9d727c76b433849c6a81dc Binary files /dev/null and b/fluid/image_classification/images/resnet50_32gpus-acc1.png differ diff --git a/fluid/image_classification/reader.py b/fluid/image_classification/reader.py index 3be7667b4a3a5d1c14b54ae88bbd085b3c32dadd..50be1cdef5ff3ad612d4d447a87174a767867a02 100644 --- a/fluid/image_classification/reader.py +++ b/fluid/image_classification/reader.py @@ -15,8 +15,6 @@ THREAD = 8 BUF_SIZE = 102400 DATA_DIR = 'data/ILSVRC2012' -TRAIN_LIST = 'data/ILSVRC2012/train_list.txt' -TEST_LIST = 'data/ILSVRC2012/val_list.txt' img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1)) img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1)) @@ -131,19 +129,35 @@ def _reader_creator(file_list, mode, shuffle=False, color_jitter=False, - rotate=False): + rotate=False, + data_dir=DATA_DIR): def reader(): with open(file_list) as flist: - lines = [line.strip() for line in flist] + full_lines = [line.strip() for line in flist] if shuffle: - np.random.shuffle(lines) + np.random.shuffle(full_lines) + if mode == 'train' and os.getenv('PADDLE_TRAINING_ROLE'): + # distributed mode if the env var `PADDLE_TRAINING_ROLE` exits + trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0")) + trainer_count = int(os.getenv("PADDLE_TRAINERS", "1")) + per_node_lines = len(full_lines) / trainer_count + lines = full_lines[trainer_id * per_node_lines:(trainer_id + 1) + * per_node_lines] + print( + "read images from %d, length: %d, lines length: %d, total: %d" + % (trainer_id * per_node_lines, per_node_lines, len(lines), + len(full_lines))) + else: + lines = full_lines + for line in lines: if mode == 'train' or mode == 'val': img_path, label = line.split() - img_path = os.path.join(DATA_DIR, img_path) + img_path = img_path.replace("JPEG", "jpeg") + img_path = os.path.join(data_dir, img_path) yield img_path, int(label) elif mode == 'test': - img_path = os.path.join(DATA_DIR, line) + img_path = os.path.join(data_dir, line) yield [img_path] mapper = functools.partial( @@ -152,14 +166,17 @@ def _reader_creator(file_list, return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE) -def train(file_list=TRAIN_LIST): +def train(data_dir=DATA_DIR): + file_list = os.path.join(data_dir, 'train_list.txt') return _reader_creator( - file_list, 'train', shuffle=True, color_jitter=False, rotate=False) + file_list, 'train', shuffle=True, color_jitter=False, rotate=False, data_dir=data_dir) -def val(file_list=TEST_LIST): - return _reader_creator(file_list, 'val', shuffle=False) +def val(data_dir=DATA_DIR): + file_list = os.path.join(data_dir, 'val_list.txt') + return _reader_creator(file_list, 'val', shuffle=False, data_dir=data_dir) -def test(file_list=TEST_LIST): - return 
_reader_creator(file_list, 'test', shuffle=False)
+def test(data_dir=DATA_DIR):
+    file_list = os.path.join(data_dir, 'val_list.txt')
+    return _reader_creator(file_list, 'test', shuffle=False, data_dir=data_dir)
diff --git a/fluid/image_classification/train.py b/fluid/image_classification/train.py
index 75a6cfb035de96d4287f31ffa849d69b4b40b1a8..bfc5f8b1412a11606d54b020f29bef969bae2a62 100644
--- a/fluid/image_classification/train.py
+++ b/fluid/image_classification/train.py
@@ -33,6 +33,7 @@ add_arg('lr', float, 0.1, "set learning rate.")
 add_arg('lr_strategy', str, "piecewise_decay", "Set the learning rate decay strategy.")
 add_arg('model', str, "SE_ResNeXt50_32x4d", "Set the network to use.")
 add_arg('enable_ce', bool, False, "If set True, enable continuous evaluation job.")
+add_arg('data_dir', str, "./data/ILSVRC2012", "The ImageNet dataset root dir.")
 # yapf: enable
 
 model_list = [m for m in dir(models) if "__" not in m]