Unverified commit 7f4274fb authored by xuezhong, committed by GitHub

Merge pull request #1018 from xuezhong/transformer_merge

distributed training for transformer
...@@ -110,7 +110,7 @@ python -u train.py \
```
For more details on these parameters, please refer to the comments in `config.py`.

Training uses all available GPUs by default; the number of GPUs used can be set with the `CUDA_VISIBLE_DEVICES` environment variable. Training can also run on CPU only (via the `--device CPU` argument), which is comparatively slow. During training, the model is saved to the directory specified by `model_dir` at the end of each epoch, and each iteration prints a log line like the following to standard output:
```txt
epoch: 0, batch: 0, sum loss: 258793.343750, avg loss: 11.069005, ppl: 64151.644531
epoch: 0, batch: 1, sum loss: 256140.718750, avg loss: 11.059616, ppl: 63552.148438
...@@ -154,9 +154,82 @@ perl multi-bleu.perl data/newstest2013.tok.de < predict.tok.txt
```
BLEU = 25.08, 58.3/31.5/19.6/12.6 (BP=0.966, ratio=0.967, hyp_len=61321, ref_len=63412)
```
### Distributed Training

The Transformer model supports both synchronous and asynchronous distributed training. The distributed setup involves two kinds of configuration (a sketch of how they fit together follows the examples below):

1 Command-line arguments

- `--local`, which takes two values: `True` selects single-machine training and `False` selects distributed training. The default is single-machine training.
- `--sync`, which takes two values but only takes effect when `--local` is `False`: `True` selects synchronous training and `False` selects asynchronous training. The default is synchronous training.

2 Environment variables

In distributed mode, the number of trainers and the number of pservers are configured manually. In the network topology, every trainer connects to every pserver; pservers act as servers and trainers as clients. The settings for pservers and trainers are described separately below:

1) pserver configuration

- `PADDLE_IS_LOCAL=[0|1]` whether training is distributed: `0` means distributed, `1` means single-machine
- `TRAINING_ROLE=PSERVER` marks the current node as a pserver
- `POD_IP=ip` the address on which this pserver serves
- `PADDLE_PORT=port` the port this pserver listens on; together with `POD_IP` it forms the pserver's unique external identity
- `PADDLE_TRAINERS_NUM=num` the number of trainers that connect to this pserver

Below is an example configuration with two pservers. On 192.168.2.2:
```
export PADDLE_PSERVERS=192.168.2.2,192.168.2.3
export POD_IP=192.168.2.2
export PADDLE_TRAINERS_NUM=2
export TRAINING_ROLE=PSERVER
export PADDLE_IS_LOCAL=0
export PADDLE_PORT=6177
```
On 192.168.2.3:
```
export PADDLE_PSERVERS=192.168.2.2,192.168.2.3
export POD_IP=192.168.2.3
export PADDLE_TRAINERS_NUM=2
export TRAINING_ROLE=PSERVER
export PADDLE_IS_LOCAL=0
export PADDLE_PORT=6177
```
2) trainer configuration

- `PADDLE_IS_LOCAL=[0|1]` whether training is distributed: `0` means distributed, `1` means single-machine
- `TRAINING_ROLE=TRAINER` marks the current node as a trainer
- `PADDLE_PSERVERS=[ip1,ip2,...]` the IP addresses of the pservers the trainer connects to, separated by `,`
- `PADDLE_TRAINER_ID=num` the id of the current trainer node, an integer from 0 to N-1
- `PADDLE_PORT=port` the pserver service port to request

Below is an example configuration with two trainers. On trainer 1:
```
export TRAINING_ROLE=TRAINER
export PADDLE_PSERVERS=192.168.2.2,192.168.2.3
export PADDLE_TRAINERS_NUM=2
export PADDLE_TRAINER_ID=0
export PADDLE_IS_LOCAL=0
export PADDLE_PORT=6177
```
On trainer 2:
```
export TRAINING_ROLE=TRAINER
export PADDLE_PSERVERS=192.168.2.2,192.168.2.3
export PADDLE_TRAINERS_NUM=2
export PADDLE_TRAINER_ID=1
export PADDLE_IS_LOCAL=0
export PADDLE_PORT=6177
```
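As a quick orientation, here is a minimal sketch of how the variables above are typically combined. It is illustrative only, not code from this PR (names such as `role` and `endpoints` are ours); it mirrors the endpoint assembly and role dispatch that `train.py` performs, with the actual Fluid calls elided:

```
import os

# Build the "ip:port,ip:port" pserver endpoint list from PADDLE_PSERVERS and
# PADDLE_PORT, then dispatch on TRAINING_ROLE, as train.py does.
port = os.getenv("PADDLE_PORT", "6174")  # train.py defaults to 6174 when unset
pserver_ips = os.getenv("PADDLE_PSERVERS", "")  # "ip1,ip2,..."
endpoints = ",".join(["%s:%s" % (ip, port) for ip in pserver_ips.split(",")])
trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "0"))
role = os.getenv("TRAINING_ROLE", "TRAINER")

if role == "PSERVER":
    # A pserver serves parameters at its own POD_IP:PADDLE_PORT.
    current_endpoint = "%s:%s" % (os.getenv("POD_IP"), port)
    print("pserver listening on %s, expecting %d trainers" %
          (current_endpoint, trainers))
elif role == "TRAINER":
    # A trainer identifies itself by PADDLE_TRAINER_ID, an integer in [0, N-1].
    trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
    print("trainer %d connecting to pservers %s" % (trainer_id, endpoints))
```

In practice you do not run this sketch; exporting the variables shown earlier and launching `train.py` (with `PADDLE_IS_LOCAL=0`) makes each process assume the role given by `TRAINING_ROLE`.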
### References

1. Vaswani A, Shazeer N, Parmar N, et al. [Attention is all you need](http://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf)[C]//Advances in Neural Information Processing Systems. 2017: 6000-6010.
2. He K, Zhang X, Ren S, et al. [Deep residual learning for image recognition](http://openaccess.thecvf.com/content_cvpr_2016/papers/He_Deep_Residual_Learning_CVPR_2016_paper.pdf)[C]//Proceedings of the IEEE conference on computer vision and pattern recognition. 2016: 770-778.
3. Ba J L, Kiros J R, Hinton G E. [Layer normalization](https://arxiv.org/pdf/1607.06450.pdf)[J]. arXiv preprint arXiv:1607.06450, 2016.
...@@ -3,6 +3,7 @@ import time
import argparse
import ast
import numpy as np
import multiprocessing
import paddle
import paddle.fluid as fluid
...@@ -80,6 +81,20 @@ def parse_args():
        help='See config.py for all options',
        default=None,
        nargs=argparse.REMAINDER)
    parser.add_argument(
        '--local',
        type=ast.literal_eval,
        default=True,
        help='Whether to run as local mode.')
    parser.add_argument(
        '--device',
        type=str,
        default='GPU',
        choices=['CPU', 'GPU'],
        help="The device type.")
    parser.add_argument(
        '--sync', type=ast.literal_eval, default=True, help="sync mode.")
    args = parser.parse_args()
    # Append args related to dict
    src_dict = reader.DataReader.load_dict(args.src_vocab_fpath)
...@@ -247,34 +262,73 @@ def split_data(data, num_part):
    ]
def test_context(train_progm, avg_cost, train_exe, dev_count, data_input_names,
                 util_input_names, sum_cost, token_num):
    # Context to do validation.
    test_program = train_progm.clone()
    with fluid.program_guard(test_program):
        test_program = fluid.io.get_inference_program([avg_cost])

    val_data = reader.DataReader(
        src_vocab_fpath=args.src_vocab_fpath,
        trg_vocab_fpath=args.trg_vocab_fpath,
        fpattern=args.val_file_pattern,
        use_token_batch=args.use_token_batch,
        batch_size=args.batch_size * (1 if args.use_token_batch else dev_count),
        pool_size=args.pool_size,
        sort_type=args.sort_type,
        start_mark=args.special_token[0],
        end_mark=args.special_token[1],
        unk_mark=args.special_token[2],
        # count start and end tokens out
        max_length=ModelHyperParams.max_length - 2,
        clip_last_batch=False,
        shuffle=False,
        shuffle_batch=False)

    test_exe = fluid.ParallelExecutor(
        use_cuda=TrainTaskConfig.use_gpu,
        main_program=test_program,
        share_vars_from=train_exe)

    def test(exe=test_exe):
        test_total_cost = 0
        test_total_token = 0
        test_data = read_multiple(
            reader=val_data.batch_generator,
            count=dev_count if args.use_token_batch else 1)
        for batch_id, data in enumerate(test_data()):
            feed_list = []
            for place_id, data_buffer in enumerate(
                    split_data(
                        data, num_part=dev_count)):
                data_input_dict, util_input_dict, _ = prepare_batch_input(
                    data_buffer, data_input_names, util_input_names,
                    ModelHyperParams.eos_idx, ModelHyperParams.eos_idx,
                    ModelHyperParams.n_head, ModelHyperParams.d_model)
                feed_list.append(
                    dict(data_input_dict.items() + util_input_dict.items()))

            outs = exe.run(feed=feed_list,
                           fetch_list=[sum_cost.name, token_num.name])
            sum_cost_val, token_num_val = np.array(outs[0]), np.array(outs[1])
            test_total_cost += sum_cost_val.sum()
            test_total_token += token_num_val.sum()
        test_avg_cost = test_total_cost / test_total_token
        test_ppl = np.exp([min(test_avg_cost, 100)])
        return test_avg_cost, test_ppl

    return test
def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
               token_num, predict):
    # Initialize the parameters.
    if TrainTaskConfig.ckpt_path:
        fluid.io.load_persistables(exe, TrainTaskConfig.ckpt_path)
        lr_scheduler.current_steps = TrainTaskConfig.start_step
    else:
        print "init fluid.framework.default_startup_program"
        exe.run(fluid.framework.default_startup_program())

    train_data = reader.DataReader(
...@@ -305,77 +359,24 @@ def train(args):
    train_exe = fluid.ParallelExecutor(
        use_cuda=TrainTaskConfig.use_gpu,
        loss_name=sum_cost.name,
        main_program=train_progm,
        build_strategy=build_strategy)

    data_input_names = encoder_data_input_fields + decoder_data_input_fields[:
        -1] + label_data_input_fields
    util_input_names = encoder_util_input_fields + decoder_util_input_fields

    if args.val_file_pattern is not None:
        test = test_context(train_progm, avg_cost, train_exe, dev_count,
                            data_input_names, util_input_names, sum_cost,
                            token_num)
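    # The validation executor created in test_context shares its parameter
    # scope with train_exe (share_vars_from), so evaluation always sees the
    # latest trained weights.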
    init = False
    for pass_id in xrange(TrainTaskConfig.pass_num):
        pass_start_time = time.time()
        for batch_id, data in enumerate(train_data()):
            feed_list = []
            total_num_token = 0
            for place_id, data_buffer in enumerate(
                    split_data(
                        data, num_part=dev_count)):
...@@ -384,11 +385,16 @@ def train(args):
                    ModelHyperParams.eos_idx, ModelHyperParams.eos_idx,
                    ModelHyperParams.n_head, ModelHyperParams.d_model)
                total_num_token += num_token
                feed_kv_pairs = data_input_dict.items() + util_input_dict.items(
                )
                if args.local:
                    lr_rate = lr_scheduler.update_learning_rate()
                    feed_kv_pairs += {
                        lr_scheduler.learning_rate.name: lr_rate
                    }.items()
                feed_list.append(dict(feed_kv_pairs))

                if not init:  # init the position encoding table
                    for pos_enc_param_name in pos_enc_param_names:
                        pos_enc = position_encoding_init(
                            ModelHyperParams.max_length + 1,
...@@ -408,10 +414,10 @@ def train(args):
                   np.exp([min(total_avg_cost, 100)])))
            init = True

        # Validate and save the model for inference.
        print("epoch: %d, " % pass_id +
              ("val avg loss: %f, val ppl: %f, " % test()
               if args.val_file_pattern is not None else "") + "consumed %fs" %
              (time.time() - pass_start_time))
        fluid.io.save_persistables(
            exe,
            os.path.join(TrainTaskConfig.ckpt_dir,
...@@ -422,6 +428,107 @@ def train(args):
            data_input_names[:-2] + util_input_names, [predict], exe)
def train(args):
    # priority: ENV > args > config
    is_local = os.getenv("PADDLE_IS_LOCAL", "1")
    if is_local == '0':
        args.local = False
    print args

    if args.device == 'CPU':
        TrainTaskConfig.use_gpu = False

    training_role = os.getenv("TRAINING_ROLE", "TRAINER")

    if training_role == "PSERVER" or (not TrainTaskConfig.use_gpu):
        place = fluid.CPUPlace()
        dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
    else:
        place = fluid.CUDAPlace(0)
        dev_count = fluid.core.get_cuda_device_count()

    exe = fluid.Executor(place)

    sum_cost, avg_cost, predict, token_num = transformer(
        ModelHyperParams.src_vocab_size, ModelHyperParams.trg_vocab_size,
        ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
        ModelHyperParams.n_head, ModelHyperParams.d_key,
        ModelHyperParams.d_value, ModelHyperParams.d_model,
        ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
        ModelHyperParams.weight_sharing, TrainTaskConfig.label_smooth_eps)

    lr_scheduler = LearningRateScheduler(ModelHyperParams.d_model,
                                         TrainTaskConfig.warmup_steps,
                                         TrainTaskConfig.learning_rate)
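    # Optimizer selection: local training drives Adam with the externally
    # computed LearningRateScheduler above; asynchronous distributed training
    # (--sync False) falls back to plain SGD with a fixed learning rate; and
    # synchronous distributed training uses Adam with an in-program noam_decay
    # schedule, so the learning rate is computed inside the transpiled program.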
    if args.local:
        optimizer = fluid.optimizer.Adam(
            learning_rate=lr_scheduler.learning_rate,
            beta1=TrainTaskConfig.beta1,
            beta2=TrainTaskConfig.beta2,
            epsilon=TrainTaskConfig.eps)
        optimizer.minimize(sum_cost)
    elif args.sync == False:
        optimizer = fluid.optimizer.SGD(0.003)
        optimizer.minimize(sum_cost)
    else:
        lr_decay = fluid.layers\
            .learning_rate_scheduler\
            .noam_decay(ModelHyperParams.d_model,
                        TrainTaskConfig.warmup_steps)
        optimizer = fluid.optimizer.Adam(
            learning_rate=lr_decay,
            beta1=TrainTaskConfig.beta1,
            beta2=TrainTaskConfig.beta2,
            epsilon=TrainTaskConfig.eps)
        optimizer.minimize(sum_cost)
    if args.local:
        print("local start_up:")
        train_loop(exe,
                   fluid.default_main_program(), dev_count, sum_cost, avg_cost,
                   lr_scheduler, token_num, predict)
    else:
        port = os.getenv("PADDLE_PORT", "6174")
        pserver_ips = os.getenv("PADDLE_PSERVERS")  # ip,ip...
        eplist = []
        for ip in pserver_ips.split(","):
            eplist.append(':'.join([ip, port]))
        pserver_endpoints = ",".join(eplist)  # ip:port,ip:port...
        trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "0"))
        current_endpoint = os.getenv("POD_IP") + ":" + port
        trainer_id = int(os.getenv("PADDLE_TRAINER_ID"))
        t = fluid.DistributeTranspiler()
        t.transpile(trainer_id, pservers=pserver_endpoints, trainers=trainers)
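        # DistributeTranspiler rewrites the single-process program into
        # separate pserver and trainer programs, based on the endpoints and
        # trainer count resolved from the environment above.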
if training_role == "PSERVER":
current_endpoint = os.getenv("POD_IP") + ":" + os.getenv(
"PADDLE_PORT")
if not current_endpoint:
print("need env SERVER_ENDPOINT")
exit(1)
pserver_prog = t.get_pserver_program(current_endpoint)
pserver_startup = t.get_startup_program(current_endpoint,
pserver_prog)
print "psserver begin run"
with open('pserver_startup.desc', 'w') as f:
f.write(str(pserver_startup))
with open('pserver_prog.desc', 'w') as f:
f.write(str(pserver_prog))
exe.run(pserver_startup)
exe.run(pserver_prog)
elif training_role == "TRAINER":
trainer_prog = t.get_trainer_program()
with open('trainer_prog.desc', 'w') as f:
f.write(str(trainer_prog))
train_loop(exe, trainer_prog, dev_count, sum_cost, avg_cost,
lr_scheduler, token_num, predict)
else:
print("environment var TRAINER_ROLE should be TRAINER os PSERVER")
if __name__ == "__main__":
    args = parse_args()
    train(args)