Commit 934f1f67, authored by T typhoonzero

refine dist image classification

Parent 9a4f5786
@@ -7,13 +7,15 @@ large-scaled distributed training with two distributed mode: parameter server mo

Before getting started, please make sure you have gone through the ImageNet [Data Preparation](../README.md#data-preparation).
1. The entrypoint file is `dist_train.py`. The command-line arguments are almost the same as the original `train.py`, with the following arguments specific to distributed training.

    - `update_method`, specify the update method; can choose from local, pserver or nccl2.
    - `multi_batch_repeat`, set this greater than 1 to merge batches before pushing gradients to pservers.
    - `start_test_pass`, when to start running tests.
    - `num_threads`, how many threads will be used for ParallelExecutor.
    - `split_var`, in pserver mode, whether to split one parameter onto several pservers, default True.
    - `async_mode`, do async training, default False.
    - `reduce_strategy`, choose from "reduce", "allreduce".

    You can check out more details of the flags by `python dist_train.py --help`.
@@ -21,66 +23,27 @@ Before getting started, please make sure you have go throught the imagenet [Data

We use environment variables to distinguish the different training roles of a distributed training job.

- General envs:
    - `PADDLE_TRAINER_ID`, the unique trainer ID of a job, in the range [0, PADDLE_TRAINERS_NUM).
    - `PADDLE_TRAINERS_NUM`, the trainer count of a distributed job.
    - `PADDLE_CURRENT_ENDPOINT`, the current process endpoint.
- Pserver mode:
    - `PADDLE_TRAINING_ROLE`, the current training role, should be in [PSERVER, TRAINER].
    - `PADDLE_PSERVER_ENDPOINTS`, the parameter server endpoint list, separated by ",".
- NCCL2 mode:
    - `PADDLE_TRAINER_ENDPOINTS`, the endpoint list of all workers, separated by ",".
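For example, a minimal sketch (mirroring the `env.py` helper added later in this commit) of how a trainer process can consume these variables:

``` python
import os

# General envs shared by both modes.
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
num_trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT")

# Mode-specific endpoint lists.
pserver_endpoints = os.getenv("PADDLE_PSERVER_ENDPOINTS", "").split(",")   # pserver mode
trainer_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS", "").split(",")   # nccl2 mode
```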
### Try Out Different Distributed Training Modes

You can test whether distributed training works on a single node before deploying to the "real" cluster.

***NOTE: for best performance, we recommend using multi-process mode (see No. 3 below), together with fp16.***

1. Simply run `python dist_train.py` to start local training with the default configuration.
2. For pserver mode, run `bash run_ps_mode.sh` to start 2 pservers and 2 trainers; these 2 trainers will use GPU 0 and 1 to simulate 2 workers.
3. For nccl2 mode, run `bash run_nccl2_mode.sh` to start 2 workers.
4. For local/distributed multi-process mode, run `bash run_mp_mode.sh` (this test uses 4 GPUs).
### Visualize the Training Process

@@ -88,16 +51,10 @@ It's easy to draw the learning curve according to the training logs, for example,

the logs of ResNet50 are as follows:

``` text
Pass 0, batch 30, loss 7.569439, acc1: 0.0125, acc5: 0.0125, avg batch time 0.1720
Pass 0, batch 60, loss 7.027379, acc1: 0.0, acc5: 0.0, avg batch time 0.1551
Pass 0, batch 90, loss 6.819984, acc1: 0.0, acc5: 0.0125, avg batch time 0.1492
Pass 0, batch 120, loss 6.9076853, acc1: 0.0, acc5: 0.0125, avg batch time 0.1464
```
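To make the curve concrete, here is a minimal sketch (not part of the commit) that parses log lines in the format above and plots top-1 accuracy with matplotlib; the log path `logs/tr0.log` is an assumption matching the scripts below.

``` python
import re
import matplotlib.pyplot as plt

# Matches lines like:
# Pass 0, batch 30, loss 7.569439, acc1: 0.0125, acc5: 0.0125, avg batch time 0.1720
LINE_RE = re.compile(
    r"Pass (\d+), batch (\d+), loss ([\d.]+), acc1: ([\d.]+), acc5: ([\d.]+)")

steps, acc1 = [], []
with open("logs/tr0.log") as f:          # assumed log path
    for line in f:
        m = LINE_RE.search(line)
        if m:
            steps.append(int(m.group(2)))
            acc1.append(float(m.group(4)))

plt.plot(steps, acc1, label="top-1 train accuracy")
plt.xlabel("batch")
plt.ylabel("accuracy")
plt.legend()
plt.savefig("learning_curve.png")
```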
The below figure shows top 1 train accuracy for local training with 8 GPUs and distributed training
......
# batch_merge.py (new file; imported below in dist_train.py)
import numpy as np
import paddle.fluid as fluid


def copyback_repeat_bn_params(main_prog):
    repeat_vars = set()
    for op in main_prog.global_block().ops:
        if op.type == "batch_norm":
            repeat_vars.add(op.input("Mean")[0])
            repeat_vars.add(op.input("Variance")[0])
    for vname in repeat_vars:
        real_var = fluid.global_scope().find_var("%s.repeat.0" % vname).get_tensor()
        orig_var = fluid.global_scope().find_var(vname).get_tensor()
        orig_var.set(np.array(real_var), fluid.CUDAPlace(0))  # test on GPU0


def append_bn_repeat_init_op(main_prog, startup_prog, num_repeats):
    repeat_vars = set()
    for op in main_prog.global_block().ops:
        if op.type == "batch_norm":
            repeat_vars.add(op.input("Mean")[0])
            repeat_vars.add(op.input("Variance")[0])

    for i in range(num_repeats):
        for op in startup_prog.global_block().ops:
            if op.type == "fill_constant":
                for oname in op.output_arg_names:
                    if oname in repeat_vars:
                        var = startup_prog.global_block().var(oname)
                        repeat_var_name = "%s.repeat.%d" % (oname, i)
                        repeat_var = startup_prog.global_block().create_var(
                            name=repeat_var_name,
                            type=var.type,
                            dtype=var.dtype,
                            shape=var.shape,
                            persistable=var.persistable
                        )
                        main_prog.global_block()._clone_variable(repeat_var)
                        startup_prog.global_block().append_op(
                            type="fill_constant",
                            inputs={},
                            outputs={"Out": repeat_var},
                            attrs=op.all_attrs()
                        )
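These two helpers are only needed when `multi_batch_repeat > 1`: the batch-merge pass creates `.repeat.N` copies of each batch-norm mean/variance, which must be initialized before running the startup program and copied back before evaluation. A minimal sketch of the call order, mirroring how `dist_train.py` below uses them:

``` python
# Sketch only: assumes train_prog / startup_prog / args are built as in dist_train.py below.
if args.multi_batch_repeat > 1:
    # create and fill BN mean/variance copies named "<var>.repeat.<i>"
    append_bn_repeat_init_op(train_prog, startup_prog, args.multi_batch_repeat)
startup_exe.run(startup_prog)

# ... training passes ...

if args.multi_batch_repeat > 1:
    # fold the repeated BN statistics back into the original variables before testing
    copyback_repeat_bn_params(train_prog)
```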
@@ -16,6 +16,8 @@ import argparse
import time
import os
import traceback
import functools
import subprocess

import numpy as np
@@ -28,127 +30,115 @@ sys.path.append("..")
import models
import utils
from reader import train, val
from utility import add_arguments, print_arguments
from batch_merge import copyback_repeat_bn_params, append_bn_repeat_init_op
from dist_utils import pserver_prepare, nccl2_prepare
from env import dist_env
def parse_args():
    parser = argparse.ArgumentParser(description=__doc__)
    add_arg = functools.partial(add_arguments, argparser=parser)
    # yapf: disable
    add_arg('batch_size',       int,   256,                 "Minibatch size.")
    add_arg('use_gpu',          bool,  True,                "Whether to use GPU or not.")
    add_arg('total_images',     int,   1281167,             "Training image number.")
    add_arg('num_epochs',       int,   120,                 "number of epochs.")
    add_arg('class_dim',        int,   1000,                "Class number.")
    add_arg('image_shape',      str,   "3,224,224",         "input image size")
    add_arg('model_save_dir',   str,   "output",            "model save directory")
    add_arg('with_mem_opt',     bool,  False,               "Whether to use memory optimization or not.")
    add_arg('pretrained_model', str,   None,                "Whether to use pretrained model.")
    add_arg('checkpoint',       str,   None,                "Whether to resume checkpoint.")
    add_arg('lr',               float, 0.1,                 "set learning rate.")
    add_arg('lr_strategy',      str,   "piecewise_decay",   "Set the learning rate decay strategy.")
    add_arg('model',            str,   "DistResNet",        "Set the network to use.")
    add_arg('enable_ce',        bool,  False,               "If set True, enable continuous evaluation job.")
    add_arg('data_dir',         str,   "./data/ILSVRC2012", "The ImageNet dataset root dir.")
    add_arg('model_category',   str,   "models",            "Whether to use models_name or not, valid value:'models','models_name'")
    add_arg('fp16',             bool,  False,               "Enable half precision training with fp16.")
    add_arg('scale_loss',       float, 1.0,                 "Scale loss for fp16.")
    # for distributed
    add_arg('update_method',      str,  "local",     "Can be local, pserver, nccl2.")
    add_arg('multi_batch_repeat', int,  1,           "Batch merge repeats.")
    add_arg('start_test_pass',    int,  0,           "Start test after x passes.")
    add_arg('num_threads',        int,  8,           "Use num_threads to run the fluid program.")
    add_arg('split_var',          bool, True,        "Split params on pserver.")
    add_arg('async_mode',         bool, False,       "Async distributed training, only for pserver mode.")
    add_arg('reduce_strategy',    str,  "allreduce", "Choose from reduce or allreduce.")
    # yapf: enable

    args = parser.parse_args()
    return args
def get_device_num():
    if os.getenv("CPU_NUM"):
        return int(os.getenv("CPU_NUM"))
    visible_device = os.getenv('CUDA_VISIBLE_DEVICES')
    if visible_device:
        device_num = len(visible_device.split(','))
    else:
        device_num = subprocess.check_output(['nvidia-smi', '-L']).decode().count('\n')
    return device_num


def prepare_reader(is_train, pyreader, args):
    if is_train:
        reader = train(data_dir=args.data_dir)
    else:
        reader = val(data_dir=args.data_dir)
    if is_train:
        bs = args.batch_size / get_device_num()
    else:
        bs = 16
    pyreader.decorate_paddle_reader(
        paddle.batch(
            reader,
            batch_size=bs))
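# Comment added for clarity (not part of the commit): with the flags above,
# `--batch_size` is the per-trainer minibatch; prepare_reader() divides it by
# get_device_num() so each visible GPU (or CPU thread) feeds
# batch_size / device_num samples per step, while evaluation always uses a
# fixed batch size of 16.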
def build_program(is_train, main_prog, startup_prog, args):
    pyreader = None
    class_dim = args.class_dim
    image_shape = [int(m) for m in args.image_shape.split(",")]

    trainer_count = args.dist_env["num_trainers"]
    with fluid.program_guard(main_prog, startup_prog):
        pyreader = fluid.layers.py_reader(
            capacity=16,
            shapes=([-1] + image_shape, (-1, 1)),
            dtypes=('float32', 'int64'),
            name="train_reader" if is_train else "test_reader",
            use_double_buffer=True)
        with fluid.unique_name.guard():
            image, label = fluid.layers.read_file(pyreader)
            if args.fp16:
                image = fluid.layers.cast(image, "float16")
            model_def = models.__dict__[args.model](layers=50, is_train=is_train)
            predict = model_def.net(image, class_dim=class_dim)
            cost, pred = fluid.layers.softmax_with_cross_entropy(predict, label, return_softmax=True)
            if args.scale_loss > 1:
                avg_cost = fluid.layers.mean(x=cost) * float(args.scale_loss)
            else:
                avg_cost = fluid.layers.mean(x=cost)

            batch_acc1 = fluid.layers.accuracy(input=pred, label=label, k=1)
            batch_acc5 = fluid.layers.accuracy(input=pred, label=label, k=5)

            optimizer = None
            if is_train:
                start_lr = args.lr
                # n * worker * repeat
                end_lr = args.lr * trainer_count * args.multi_batch_repeat
                total_images = args.total_images / trainer_count
                step = int(total_images / (args.batch_size * args.multi_batch_repeat) + 1)
                warmup_steps = step * 5  # warmup 5 passes
                epochs = [30, 60, 80]
                bd = [step * e for e in epochs]
                base_lr = end_lr
                lr = []
                lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)]
                print("start lr: %s, end lr: %s, decay boundaries: %s" % (
                    start_lr,
                    end_lr,
                    bd
                ))

                # NOTE: we put weight decay in layers config, and remove
                # weight decay on bn layers, so don't add weight decay in
@@ -159,151 +149,77 @@ def get_model(args, is_train, main_prog, startup_prog):
                        boundaries=bd, values=lr),
                    warmup_steps, start_lr, end_lr),
                momentum=0.9)
                if args.fp16:
                    params_grads = optimizer.backward(avg_cost)
                    master_params_grads = utils.create_master_params_grads(
                        params_grads, main_prog, startup_prog, args.scale_loss)
                    optimizer.apply_gradients(master_params_grads)
                    utils.master_param_to_train_param(master_params_grads, params_grads, main_prog)
                else:
                    optimizer.minimize(avg_cost)

    # prepare reader for current program
    prepare_reader(is_train, pyreader, args)

    return pyreader, avg_cost, batch_acc1, batch_acc5
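# Worked example (added for clarity, not in the commit): with the defaults above
# (total_images=1281167, batch_size=256, multi_batch_repeat=1) and a single trainer,
# step = int(1281167 / 256 + 1) = 5005 iterations per pass, warmup_steps = 5 * 5005 = 25025,
# and the piecewise decay boundaries fall at passes 30/60/80, i.e. steps
# [150150, 300300, 400400], with learning rates [0.1, 0.01, 0.001, 0.0001].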
def test_single(exe, test_prog, args, pyreader, fetch_list):
    acc1 = fluid.metrics.Accuracy()
    acc5 = fluid.metrics.Accuracy()
    test_losses = []
    pyreader.start()
    while True:
        try:
            acc_rets = exe.run(program=test_prog, fetch_list=fetch_list)
            test_losses.append(acc_rets[0])
            acc1.update(value=np.array(acc_rets[1]), weight=args.batch_size)
            acc5.update(value=np.array(acc_rets[2]), weight=args.batch_size)
        except fluid.core.EOFException:
            pyreader.reset()
            break
    test_avg_loss = np.mean(np.array(test_losses))
    return test_avg_loss, np.mean(acc1.eval()), np.mean(acc5.eval())


def run_pserver(train_prog, startup_prog):
    server_exe = fluid.Executor(fluid.CPUPlace())
    server_exe.run(startup_prog)
    server_exe.run(train_prog)
def train_parallel(args):
    train_prog = fluid.Program()
    test_prog = fluid.Program()
    startup_prog = fluid.Program()

    train_pyreader, train_cost, train_acc1, train_acc5 = build_program(True, train_prog, startup_prog, args)
    test_pyreader, test_cost, test_acc1, test_acc5 = build_program(False, test_prog, startup_prog, args)

    if args.update_method == "pserver":
        train_prog, startup_prog = pserver_prepare(args, train_prog, startup_prog)
    elif args.update_method == "nccl2":
        nccl2_prepare(args, startup_prog)

    if args.dist_env["training_role"] == "PSERVER":
        run_pserver(train_prog, startup_prog)
        exit(0)

    if args.use_gpu:
        # NOTE: for multi process mode: one process per GPU device.
        gpu_id = 0
        if os.getenv("FLAGS_selected_gpus"):
            gpu_id = int(os.getenv("FLAGS_selected_gpus"))
    place = core.CUDAPlace(gpu_id) if args.use_gpu else core.CPUPlace()

    startup_exe = fluid.Executor(place)
    if args.multi_batch_repeat > 1:
        append_bn_repeat_init_op(train_prog, startup_prog, args.multi_batch_repeat)
    startup_exe.run(startup_prog)

    strategy = fluid.ExecutionStrategy()
    strategy.num_threads = args.num_threads
    build_strategy = fluid.BuildStrategy()
    if args.multi_batch_repeat > 1:
        pass_builder = build_strategy._finalize_strategy_and_create_passes()
        mypass = pass_builder.insert_pass(
            len(pass_builder.all_passes()) - 2, "multi_batch_merge_pass")
        mypass.set_int("num_repeats", args.multi_batch_repeat)
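    # Comment added for clarity (not part of the commit): the multi_batch_merge_pass
    # accumulates gradients over `multi_batch_repeat` consecutive mini-batches before
    # each parameter update, so the effective batch size per trainer (and hence the
    # end_lr scaling in build_program above) grows to batch_size * multi_batch_repeat.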
@@ -314,73 +230,65 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
        build_strategy.reduce_strategy = fluid.BuildStrategy(
        ).ReduceStrategy.AllReduce

    if args.update_method == "pserver" or args.update_method == "local":
        # parameter server mode distributed training, merge
        # gradients on local server, do not initialize
        # ParallelExecutor with multi server all-reduce mode.
        num_trainers = 1
        trainer_id = 0
    else:
        num_trainers = args.dist_env["num_trainers"]
        trainer_id = args.dist_env["trainer_id"]

    exe = fluid.ParallelExecutor(
        True,
        train_cost.name,
        main_program=train_prog,
        exec_strategy=strategy,
        build_strategy=build_strategy,
        num_trainers=num_trainers,
        trainer_id=trainer_id)

    over_all_start = time.time()
    fetch_list = [train_cost.name, train_acc1.name, train_acc5.name]
    for pass_id in range(args.num_epochs):
        num_samples = 0
        start_time = time.time()
        batch_id = 1
        train_pyreader.start()
        while True:
            try:
                if batch_id % 30 == 0:
                    fetch_ret = exe.run(fetch_list)
                    fetched_data = [np.mean(np.array(d)) for d in fetch_ret]
                    print("Pass %d, batch %d, loss %s, acc1: %s, acc5: %s, avg batch time %.4f" %
                          (pass_id, batch_id, fetched_data[0], fetched_data[1],
                           fetched_data[2], (time.time() - start_time) / batch_id))
                else:
                    fetch_ret = exe.run([])
            except fluid.core.EOFException:
                break
            except fluid.core.EnforceNotMet:
                traceback.print_exc()
                break
            num_samples += args.batch_size
            batch_id += 1

        print_train_time(start_time, time.time(), num_samples)
        train_pyreader.reset()

        if pass_id > args.start_test_pass:
            if args.multi_batch_repeat > 1:
                copyback_repeat_bn_params(train_prog)
            test_fetch_list = [test_cost.name, test_acc1.name, test_acc5.name]
            test_ret = test_single(startup_exe, test_prog, args, test_pyreader, test_fetch_list)
            print("Pass: %d, Test Loss %s, test acc1: %s, test acc5: %s\n" %
                  (pass_id, test_ret[0], test_ret[1], test_ret[2]))

    startup_exe.close()
    print("total train time: ", time.time() - over_all_start)
def print_train_time(start_time, end_time, num_samples):
    train_elapsed = end_time - start_time
    examples_per_sec = num_samples / train_elapsed
@@ -400,47 +308,8 @@ def main():
    args = parse_args()
    print_arguments(args)
    print_paddle_envs()
    args.dist_env = dist_env()
    train_parallel(args)


if __name__ == "__main__":
    main()
# dist_utils.py (new file; provides pserver_prepare and nccl2_prepare used above)
import os
import paddle.fluid as fluid


def nccl2_prepare(args, startup_prog):
    config = fluid.DistributeTranspilerConfig()
    config.mode = "nccl2"
    t = fluid.DistributeTranspiler(config=config)
    envs = args.dist_env
    t.transpile(envs["trainer_id"],
                trainers=','.join(envs["trainer_endpoints"]),
                current_endpoint=envs["current_endpoint"],
                startup_program=startup_prog)


def pserver_prepare(args, train_prog, startup_prog):
    config = fluid.DistributeTranspilerConfig()
    config.slice_var_up = args.split_var
    t = fluid.DistributeTranspiler(config=config)
    envs = args.dist_env
    training_role = envs["training_role"]
    t.transpile(
        envs["trainer_id"],
        program=train_prog,
        pservers=envs["pserver_endpoints"],
        trainers=envs["num_trainers"],
        sync_mode=not args.async_mode,
        startup_program=startup_prog)
    if training_role == "PSERVER":
        pserver_program = t.get_pserver_program(envs["current_endpoint"])
        pserver_startup_program = t.get_startup_program(
            envs["current_endpoint"], pserver_program, startup_program=startup_prog)
        return pserver_program, pserver_startup_program
    elif training_role == "TRAINER":
        train_program = t.get_trainer_program()
        return train_program, startup_prog
    else:
        raise ValueError(
            'PADDLE_TRAINING_ROLE environment variable must be either TRAINER or PSERVER'
        )
# env.py (new file; dist_env() is attached to args in dist_train.py's main())
import os


def dist_env():
    """
    Return a dict of all variables that distributed training may use.
    NOTE: you may rewrite this function to suit your cluster environments.
    """
    trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
    num_trainers = 1
    training_role = os.getenv("PADDLE_TRAINING_ROLE", "TRAINER")
    assert(training_role == "PSERVER" or training_role == "TRAINER")

    # - PADDLE_TRAINER_ENDPOINTS means nccl2 mode.
    # - PADDLE_PSERVER_ENDPOINTS means pserver mode.
    # - PADDLE_CURRENT_ENDPOINT means current process endpoint.
    trainer_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS")
    pserver_endpoints = os.getenv("PADDLE_PSERVER_ENDPOINTS")
    current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT")
    if trainer_endpoints:
        trainer_endpoints = trainer_endpoints.split(",")
        num_trainers = len(trainer_endpoints)
    elif pserver_endpoints:
        num_trainers = int(os.getenv("PADDLE_TRAINERS_NUM"))

    return {
        "trainer_id": trainer_id,
        "num_trainers": num_trainers,
        "current_endpoint": current_endpoint,
        "training_role": training_role,
        "pserver_endpoints": pserver_endpoints,
        "trainer_endpoints": trainer_endpoints
    }
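A quick illustration (not part of the commit) of what `dist_env()` returns when the nccl2-mode variables from `run_nccl2_mode.sh` below are set for trainer 0:

``` python
import os

# assumed example values, matching run_nccl2_mode.sh below
os.environ["PADDLE_TRAINING_ROLE"] = "TRAINER"
os.environ["PADDLE_TRAINER_ID"] = "0"
os.environ["PADDLE_CURRENT_ENDPOINT"] = "127.0.0.1:7160"
os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:7160,127.0.0.1:7161"

from env import dist_env
print(dist_env())
# {'trainer_id': 0, 'num_trainers': 2,
#  'current_endpoint': '127.0.0.1:7160', 'training_role': 'TRAINER',
#  'pserver_endpoints': None,
#  'trainer_endpoints': ['127.0.0.1:7160', '127.0.0.1:7161']}
```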
#!/bin/bash
# run_mp_mode.sh: test local/distributed multi-process mode using 4 GPUs.
export CUDA_VISIBLE_DEVICES="0,1,2,3"
export MODEL="DistResNet"
export PADDLE_TRAINER_ENDPOINTS="127.0.0.1:7160,127.0.0.1:7161,127.0.0.1:7162,127.0.0.1:7163"
# PADDLE_TRAINERS_NUM is only used by the reader in nccl2 mode.
export PADDLE_TRAINERS_NUM="4"
mkdir -p logs
for i in {0..3}
do
PADDLE_TRAINING_ROLE="TRAINER" \
PADDLE_CURRENT_ENDPOINT="127.0.0.1:716${i}" \
PADDLE_TRAINER_ID="${i}" \
FLAGS_selected_gpus="${i}" \
python dist_train.py --model $MODEL --update_method nccl2 --batch_size 32 --fp16 1 --scale_loss 8 &> logs/tr$i.log &
done
#!/bin/bash
# run_nccl2_mode.sh: start 2 nccl2-mode workers on one node.
export MODEL="DistResNet"
export PADDLE_TRAINER_ENDPOINTS="127.0.0.1:7160,127.0.0.1:7161"
# PADDLE_TRAINERS_NUM is only used by the reader in nccl2 mode.
export PADDLE_TRAINERS_NUM="2"
mkdir -p logs

# NOTE: set NCCL_P2P_DISABLE so that we can run nccl2 distributed training on one node.
PADDLE_TRAINING_ROLE="TRAINER" \
PADDLE_CURRENT_ENDPOINT="127.0.0.1:7160" \
PADDLE_TRAINER_ID="0" \
CUDA_VISIBLE_DEVICES="0" \
NCCL_P2P_DISABLE="1" \
python dist_train.py --model $MODEL --update_method nccl2 --batch_size 32 &> logs/tr0.log &
PADDLE_TRAINING_ROLE="TRAINER" \
PADDLE_CURRENT_ENDPOINT="127.0.0.1:7161" \
PADDLE_TRAINER_ID="1" \
CUDA_VISIBLE_DEVICES="1" \
NCCL_P2P_DISABLE="1" \
python dist_train.py --model $MODEL --update_method nccl2 --batch_size 32 &> logs/tr1.log &
#!/bin/bash
# run_ps_mode.sh: start 2 pservers and 2 trainers (GPU 0 and 1) in pserver mode.
export MODEL="DistResNet"
export PADDLE_PSERVER_ENDPOINTS="127.0.0.1:7160,127.0.0.1:7161"
export PADDLE_TRAINERS_NUM="2"
mkdir -p logs
PADDLE_TRAINING_ROLE="PSERVER" \
PADDLE_CURRENT_ENDPOINT="127.0.0.1:7160" \
python dist_train.py --model $MODEL --update_method pserver --batch_size 32 &> logs/ps0.log &
PADDLE_TRAINING_ROLE="PSERVER" \
PADDLE_CURRENT_ENDPOINT="127.0.0.1:7161" \
python dist_train.py --model $MODEL --update_method pserver --batch_size 32 &> logs/ps1.log &
PADDLE_TRAINING_ROLE="TRAINER" \
PADDLE_CURRENT_ENDPOINT="127.0.0.1:7160" \
PADDLE_TRAINER_ID="0" \
CUDA_VISIBLE_DEVICES="0" \
python dist_train.py --model $MODEL --update_method pserver --batch_size 32 &> logs/tr0.log &
PADDLE_TRAINING_ROLE="TRAINER" \
PADDLE_CURRENT_ENDPOINT="127.0.0.1:7161" \
PADDLE_TRAINER_ID="1" \
CUDA_VISIBLE_DEVICES="1" \
python dist_train.py --model $MODEL --update_method pserver --batch_size 32 &> logs/tr1.log &
@@ -14,8 +14,9 @@ train_parameters = {
    "learning_strategy": {
        "name": "piecewise_decay",
        "batch_size": 256,
        "epochs": [30, 60, 80],
        "steps": [0.1, 0.01, 0.001, 0.0001],
        "warmup_passes": 5
    }
}
@@ -118,3 +119,4 @@ class DistResNet():
        short = self.shortcut(input, num_filters * 4, stride)

        return fluid.layers.elementwise_add(x=short, y=conv2, act='relu')
@@ -139,7 +139,7 @@ def _reader_creator(file_list,
            if mode == 'train' and os.getenv('PADDLE_TRAINING_ROLE'):
                # distributed mode if the env var `PADDLE_TRAINING_ROLE` exists
                trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
                trainer_count = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
                per_node_lines = len(full_lines) // trainer_count
                lines = full_lines[trainer_id * per_node_lines:(trainer_id + 1)
                                   * per_node_lines]
......
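To make the sharding above concrete, a small standalone sketch (not from the repo) of the same arithmetic:

``` python
# Each trainer reads a disjoint, equally sized slice of the file list.
full_lines = ["img_%d" % i for i in range(10)]     # toy file list
trainer_count = 3
per_node_lines = len(full_lines) // trainer_count  # 3 lines per trainer

for trainer_id in range(trainer_count):
    lines = full_lines[trainer_id * per_node_lines:(trainer_id + 1) * per_node_lines]
    print(trainer_id, lines)
# 0 ['img_0', 'img_1', 'img_2']
# 1 ['img_3', 'img_4', 'img_5']
# 2 ['img_6', 'img_7', 'img_8']   (the trailing remainder 'img_9' is dropped)
```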