未验证 提交 67edb3e1 编写于 作者: Y Yibing Liu 提交者: GitHub

Merge pull request #174 from chengduoZH/benchmark

[NOT MERGE]Support multi process training for BERT
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import paddle.fluid as fluid
def nccl2_prepare(trainer_id, startup_prog, main_prog):
    """Transpile the given programs in "nccl2" mode.

    The trainer endpoint list and the endpoint of the current process are
    read from the ``PADDLE_TRAINER_ENDPOINTS`` and ``PADDLE_CURRENT_ENDPOINT``
    environment variables and forwarded to
    ``fluid.DistributeTranspiler.transpile``.

    Args:
        trainer_id: Id of the current trainer; passed through unchanged.
        startup_prog: Program handed to the transpiler as ``startup_program``.
        main_prog: Program handed to the transpiler as ``program``.

    Raises:
        ValueError: If either environment variable is not set.
    """
    endpoints = os.environ.get('PADDLE_TRAINER_ENDPOINTS')
    current_endpoint = os.environ.get('PADDLE_CURRENT_ENDPOINT')

    # Fail early with a readable message instead of handing None to the
    # transpiler. Only a truly unset variable is rejected; any string value
    # (even an empty one) is forwarded exactly as before.
    required = (
        ('PADDLE_TRAINER_ENDPOINTS', endpoints),
        ('PADDLE_CURRENT_ENDPOINT', current_endpoint),
    )
    missing = [name for name, value in required if value is None]
    if missing:
        raise ValueError(
            "nccl2_prepare requires the environment variable(s) %s to be set"
            % ", ".join(missing))

    config = fluid.DistributeTranspilerConfig()
    config.mode = "nccl2"
    t = fluid.DistributeTranspiler(config=config)
    t.transpile(
        trainer_id,
        trainers=endpoints,
        current_endpoint=current_endpoint,
        startup_program=startup_prog,
        program=main_prog)
def prepare_for_multi_process(exe, build_strategy, train_prog):
    """Set up multi-process training when more than one trainer is configured.

    The trainer id and trainer count are read from the ``PADDLE_TRAINER_ID``
    (default 0) and ``PADDLE_TRAINERS_NUM`` (default 1) environment
    variables. With fewer than two trainers the function does nothing.
    Otherwise it prints both values, stores them on ``build_strategy``,
    passes a fresh ``fluid.Program`` together with ``train_prog`` to
    ``nccl2_prepare``, and runs that fresh program on ``exe``.

    Args:
        exe: Executor whose ``run`` method is called with the new program.
        build_strategy: Object that receives ``num_trainers`` and
            ``trainer_id`` attributes.
        train_prog: Training program forwarded to ``nccl2_prepare``.
    """
    rank = int(os.environ.get('PADDLE_TRAINER_ID', 0))
    world_size = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))

    if world_size >= 2:
        print("PADDLE_TRAINERS_NUM", world_size)
        print("PADDLE_TRAINER_ID", rank)
        build_strategy.num_trainers = world_size
        build_strategy.trainer_id = rank

        # NOTE(zcd): multiple processes are used to train the model,
        # and each process uses one GPU card.
        nccl_startup = fluid.Program()
        nccl2_prepare(rank, nccl_startup, train_prog)

        # Original author's note: the startup program ends up being run
        # twice, which is said to be harmless.
        exe.run(nccl_startup)
...@@ -123,7 +123,8 @@ class DataProcessor(object): ...@@ -123,7 +123,8 @@ class DataProcessor(object):
phase='train', phase='train',
epoch=1, epoch=1,
dev_count=1, dev_count=1,
shuffle=True): shuffle=True,
shuffle_seed=None):
""" """
Generate data for train, dev or test. Generate data for train, dev or test.
...@@ -149,6 +150,8 @@ class DataProcessor(object): ...@@ -149,6 +150,8 @@ class DataProcessor(object):
def instance_reader(): def instance_reader():
for epoch_index in range(epoch): for epoch_index in range(epoch):
if shuffle: if shuffle:
if shuffle_seed is not None:
np.random.seed(shuffle_seed)
np.random.shuffle(examples) np.random.shuffle(examples)
if phase == 'train': if phase == 'train':
self.current_train_epoch = epoch_index self.current_train_epoch = epoch_index
......
...@@ -32,6 +32,9 @@ from model.classifier import create_model ...@@ -32,6 +32,9 @@ from model.classifier import create_model
from optimization import optimization from optimization import optimization
from utils.args import ArgumentGroup, print_arguments from utils.args import ArgumentGroup, print_arguments
from utils.init import init_pretraining_params, init_checkpoint from utils.init import init_pretraining_params, init_checkpoint
import dist_utils
num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
# yapf: disable # yapf: disable
parser = argparse.ArgumentParser(__doc__) parser = argparse.ArgumentParser(__doc__)
...@@ -76,6 +79,7 @@ data_g.add_arg("random_seed", int, 0, "Random seed.") ...@@ -76,6 +79,7 @@ data_g.add_arg("random_seed", int, 0, "Random seed.")
run_type_g = ArgumentGroup(parser, "run_type", "running type options.") run_type_g = ArgumentGroup(parser, "run_type", "running type options.")
run_type_g.add_arg("use_cuda", bool, True, "If set, use GPU for training.") run_type_g.add_arg("use_cuda", bool, True, "If set, use GPU for training.")
run_type_g.add_arg("use_fast_executor", bool, False, "If set, use fast parallel executor (in experiment).") run_type_g.add_arg("use_fast_executor", bool, False, "If set, use fast parallel executor (in experiment).")
run_type_g.add_arg("shuffle", bool, True, "")
run_type_g.add_arg("num_iteration_per_drop_scope", int, 1, "Ihe iteration intervals to clean up temporary variables.") run_type_g.add_arg("num_iteration_per_drop_scope", int, 1, "Ihe iteration intervals to clean up temporary variables.")
run_type_g.add_arg("task_name", str, None, run_type_g.add_arg("task_name", str, None,
"The name of task to perform fine-tuning, should be in {'xnli', 'mnli', 'cola', 'mrpc'}.") "The name of task to perform fine-tuning, should be in {'xnli', 'mnli', 'cola', 'mrpc'}.")
...@@ -106,6 +110,15 @@ def evaluate(exe, test_program, test_pyreader, fetch_list, eval_phase): ...@@ -106,6 +110,15 @@ def evaluate(exe, test_program, test_pyreader, fetch_list, eval_phase):
(eval_phase, np.sum(total_cost) / np.sum(total_num_seqs), (eval_phase, np.sum(total_cost) / np.sum(total_num_seqs),
np.sum(total_acc) / np.sum(total_num_seqs), time_end - time_begin)) np.sum(total_acc) / np.sum(total_num_seqs), time_end - time_begin))
def get_device_num():
    """Return the number of GPU devices this process should use.

    Returns 1 when the module-level ``num_trainers`` is greater than one.
    Otherwise returns the number of comma-separated entries in
    ``CUDA_VISIBLE_DEVICES`` when that variable is set and non-empty, and
    falls back to counting the newline characters in the output of
    ``nvidia-smi -L``.
    """
    # NOTE(zcd): for multi-process training, each process uses one GPU card.
    if num_trainers > 1:
        return 1

    visible = os.environ.get('CUDA_VISIBLE_DEVICES', None)
    if not visible:
        # Unset or empty: count the newlines in the nvidia-smi listing.
        listing = subprocess.check_output(['nvidia-smi', '-L']).decode()
        return listing.count('\n')
    return len(visible.split(','))
def main(args): def main(args):
bert_config = BertConfig(args.bert_config_path) bert_config = BertConfig(args.bert_config_path)
...@@ -113,7 +126,7 @@ def main(args): ...@@ -113,7 +126,7 @@ def main(args):
if args.use_cuda: if args.use_cuda:
place = fluid.CUDAPlace(int(os.getenv('FLAGS_selected_gpus', '0'))) place = fluid.CUDAPlace(int(os.getenv('FLAGS_selected_gpus', '0')))
dev_count = fluid.core.get_cuda_device_count() dev_count = get_device_num()
else: else:
place = fluid.CPUPlace() place = fluid.CPUPlace()
dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count())) dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
...@@ -139,17 +152,24 @@ def main(args): ...@@ -139,17 +152,24 @@ def main(args):
raise ValueError("For args `do_train`, `do_val` and `do_test`, at " raise ValueError("For args `do_train`, `do_val` and `do_test`, at "
"least one of them must be True.") "least one of them must be True.")
train_program = fluid.Program()
startup_prog = fluid.Program() startup_prog = fluid.Program()
if args.random_seed is not None: if args.random_seed is not None:
startup_prog.random_seed = args.random_seed startup_prog.random_seed = args.random_seed
train_program.random_seed = args.random_seed
if args.do_train: if args.do_train:
# NOTE: If num_trainers > 1, the shuffle_seed must be set, because
# the order of batch data generated by reader
# must be the same in the respective processes.
shuffle_seed = 1 if num_trainers > 1 else None
train_data_generator = processor.data_generator( train_data_generator = processor.data_generator(
batch_size=args.batch_size, batch_size=args.batch_size,
phase='train', phase='train',
epoch=args.epoch, epoch=args.epoch,
dev_count=dev_count, dev_count=dev_count,
shuffle=True) shuffle=args.shuffle,
shuffle_seed=shuffle_seed)
num_train_examples = processor.get_num_examples(phase='train') num_train_examples = processor.get_num_examples(phase='train')
...@@ -165,8 +185,6 @@ def main(args): ...@@ -165,8 +185,6 @@ def main(args):
print("Max train steps: %d" % max_train_steps) print("Max train steps: %d" % max_train_steps)
print("Num warmup steps: %d" % warmup_steps) print("Num warmup steps: %d" % warmup_steps)
train_program = fluid.Program()
with fluid.program_guard(train_program, startup_prog): with fluid.program_guard(train_program, startup_prog):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
train_pyreader, loss, probs, accuracy, num_seqs = create_model( train_pyreader, loss, probs, accuracy, num_seqs = create_model(
...@@ -249,13 +267,21 @@ def main(args): ...@@ -249,13 +267,21 @@ def main(args):
exec_strategy.use_experimental_executor = args.use_fast_executor exec_strategy.use_experimental_executor = args.use_fast_executor
exec_strategy.num_threads = dev_count exec_strategy.num_threads = dev_count
exec_strategy.num_iteration_per_drop_scope = args.num_iteration_per_drop_scope exec_strategy.num_iteration_per_drop_scope = args.num_iteration_per_drop_scope
build_strategy = fluid.BuildStrategy()
if args.use_cuda and num_trainers > 1:
assert shuffle_seed is not None
dist_utils.prepare_for_multi_process(exe, build_strategy, train_program)
train_data_generator = fluid.contrib.reader.distributed_batch_reader(
train_data_generator)
train_exe = fluid.ParallelExecutor( train_exe = fluid.ParallelExecutor(
use_cuda=args.use_cuda, use_cuda=args.use_cuda,
loss_name=loss.name, loss_name=loss.name,
exec_strategy=exec_strategy, exec_strategy=exec_strategy,
build_strategy = build_strategy,
main_program=train_program) main_program=train_program)
train_pyreader.decorate_tensor_provider(train_data_generator) train_pyreader.decorate_tensor_provider(train_data_generator)
else: else:
train_exe = None train_exe = None
...@@ -271,9 +297,10 @@ def main(args): ...@@ -271,9 +297,10 @@ def main(args):
steps = 0 steps = 0
total_cost, total_acc, total_num_seqs = [], [], [] total_cost, total_acc, total_num_seqs = [], [], []
time_begin = time.time() time_begin = time.time()
throughput = []
while True: while True:
try: try:
steps += 1 # steps += 1
if steps % args.skip_steps == 0: if steps % args.skip_steps == 0:
if warmup_steps <= 0: if warmup_steps <= 0:
fetch_list = [loss.name, accuracy.name, num_seqs.name] fetch_list = [loss.name, accuracy.name, num_seqs.name]
...@@ -309,21 +336,29 @@ def main(args): ...@@ -309,21 +336,29 @@ def main(args):
) )
time_end = time.time() time_end = time.time()
used_time = time_end - time_begin used_time = time_end - time_begin
print("epoch: %d, progress: %d/%d, step: %d, ave loss: %f, "
"ave acc: %f, speed: %f steps/s" % log_record = "epoch: {}, progress: {}/{}, step: {}, ave loss: {}, ave acc: {}".format(
(current_epoch, current_example, num_train_examples, current_epoch, current_example, num_train_examples,
steps, np.sum(total_cost) / np.sum(total_num_seqs), steps, np.sum(total_cost) / np.sum(total_num_seqs),
np.sum(total_acc) / np.sum(total_num_seqs), np.sum(total_acc) / np.sum(total_num_seqs))
args.skip_steps / used_time)) if steps > 0 :
throughput.append( args.skip_steps / used_time)
log_record = log_record + ", speed: %f steps/s" % (args.skip_steps / used_time)
print(log_record)
else:
print(log_record)
total_cost, total_acc, total_num_seqs = [], [], [] total_cost, total_acc, total_num_seqs = [], [], []
time_begin = time.time() time_begin = time.time()
steps += 1
if steps % args.save_steps == 0: if steps % args.save_steps == 0:
save_path = os.path.join(args.checkpoints, save_path = os.path.join(args.checkpoints,
"step_" + str(steps)) "step_" + str(steps))
fluid.io.save_persistables(exe, save_path, train_program) fluid.io.save_persistables(exe, save_path, train_program)
if steps % args.validation_steps == 0: if steps % args.validation_steps == 0:
print("Average throughtput: %s" % (np.average(throughput)))
throughput = []
# evaluate dev set # evaluate dev set
if args.do_val: if args.do_val:
test_pyreader.decorate_tensor_provider( test_pyreader.decorate_tensor_provider(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册