Commit 87c5cf65 authored by Qiao Longfei

distributed training for transformer

Parent 8761ab3d
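In short, this change moves all training setup into train(args), adds --local and --device command-line flags, and branches on the PADDLE_IS_LOCAL / TRAINING_ROLE environment variables so the same script can run as a local trainer, a distributed trainer, or a parameter server (via fluid.DistributeTranspiler). The snippet below is a minimal sketch of that role and device selection pattern, written for this note using only the standard library; resolve_runtime is a made-up name, and the defaults mirror the new flags.

import os
import multiprocessing

def resolve_runtime(local_flag=True, device='GPU'):
    # Environment overrides the flag, mirroring the "priority: ENV > args > config"
    # comment in the new train().
    if os.getenv("PADDLE_IS_LOCAL", "1") == '0':
        local_flag = False
    training_role = os.getenv("TRAINING_ROLE", "TRAINER")
    use_gpu = (device == 'GPU')
    if training_role == "PSERVER" or not use_gpu:
        # Parameter servers (and CPU training) run on CPU devices.
        dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
    else:
        dev_count = 1  # stand-in for fluid.core.get_cuda_device_count()
    return local_flag, training_role, dev_count

print(resolve_runtime())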
@@ -3,6 +3,7 @@ import time
 import argparse
 import ast
 import numpy as np
+import multiprocessing
 import paddle
 import paddle.fluid as fluid
@@ -80,6 +81,18 @@ def parse_args():
         help='See config.py for all options',
         default=None,
         nargs=argparse.REMAINDER)
+    parser.add_argument(
+        '--local',
+        type=ast.literal_eval,
+        default=True,
+        help='Whether to run as local mode.')
+    parser.add_argument(
+        '--device',
+        type=str,
+        default='GPU',
+        choices=['CPU', 'GPU'],
+        help="The device type.")
     args = parser.parse_args()
     # Append args related to dict
     src_dict = reader.DataReader.load_dict(args.src_vocab_fpath)
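Editor's note on the --local flag added above: type=ast.literal_eval parses the raw command-line string as a Python literal, so --local False produces the boolean False rather than a truthy non-empty string. A quick illustration (not part of the commit):

import ast

# argparse applies the type callable to the raw string from the command line.
print(ast.literal_eval("False"))  # -> False (bool)
print(ast.literal_eval("True"))   # -> True (bool)
print(bool("False"))              # -> True, which is why type=bool would misbehave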
@@ -205,7 +218,61 @@ def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
         [num_token], dtype="float32")
-def read_multiple(reader, count, clip_last=True):
+def train(args):
+    # priority: ENV > args > config
+    is_local = os.getenv("PADDLE_IS_LOCAL", "1")
+    if is_local == '0':
+        args.local = False
+    print args
+    if args.device == 'CPU':
+        TrainTaskConfig.use_gpu = False
+    training_role = os.getenv("TRAINING_ROLE", "TRAINER")
+    if training_role == "PSERVER" or (not TrainTaskConfig.use_gpu):
+        place = fluid.CPUPlace()
+        dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
+    else:
+        place = fluid.CUDAPlace(0)
+        dev_count = fluid.core.get_cuda_device_count()
+    exe = fluid.Executor(place)
+    sum_cost, avg_cost, predict, token_num = transformer(
+        ModelHyperParams.src_vocab_size, ModelHyperParams.trg_vocab_size,
+        ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
+        ModelHyperParams.n_head, ModelHyperParams.d_key,
+        ModelHyperParams.d_value, ModelHyperParams.d_model,
+        ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
+        ModelHyperParams.weight_sharing, TrainTaskConfig.label_smooth_eps)
+    if args.local:
+        lr_scheduler = LearningRateScheduler(ModelHyperParams.d_model,
+                                             TrainTaskConfig.warmup_steps,
+                                             TrainTaskConfig.learning_rate)
+        optimizer = fluid.optimizer.Adam(
+            learning_rate=lr_scheduler.learning_rate,
+            beta1=TrainTaskConfig.beta1,
+            beta2=TrainTaskConfig.beta2,
+            epsilon=TrainTaskConfig.eps)
+        optimizer.minimize(sum_cost)
+    else:
+        lr_decay = fluid.layers\
+            .learning_rate_scheduler\
+            .noam_decay(ModelHyperParams.d_model,
+                        TrainTaskConfig.warmup_steps)
+        optimizer = fluid.optimizer.Adam(
+            learning_rate=lr_decay,
+            beta1=TrainTaskConfig.beta1,
+            beta2=TrainTaskConfig.beta2,
+            epsilon=TrainTaskConfig.eps)
+        optimizer.minimize(sum_cost)
+    def train_loop(exe, train_progm):
+        def read_multiple(reader,
+                          count=dev_count if args.use_token_batch else 1,
+                          clip_last=True):
             """
             Stack data from reader for multi-devices.
             """
@@ -226,14 +293,14 @@ def read_multiple(reader, count, clip_last=True):
                 if len(data) > count:
                     inst_num_per_part = len(data) // count
                     yield [
-                        data[inst_num_per_part * i:inst_num_per_part * (i + 1)]
+                        data[inst_num_per_part * i:inst_num_per_part * (i +
+                                                                        1)]
                         for i in range(count)
                     ]
             return __impl__
-def split_data(data, num_part):
+        def split_data(data, num_part=dev_count):
             """
             Split data for each device.
             """
@@ -246,35 +313,12 @@ def split_data(data, num_part):
                 for i in range(num_part)
             ]
-def train(args):
-    dev_count = fluid.core.get_cuda_device_count()
-    sum_cost, avg_cost, predict, token_num = transformer(
-        ModelHyperParams.src_vocab_size, ModelHyperParams.trg_vocab_size,
-        ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
-        ModelHyperParams.n_head, ModelHyperParams.d_key,
-        ModelHyperParams.d_value, ModelHyperParams.d_model,
-        ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
-        ModelHyperParams.weight_sharing, TrainTaskConfig.label_smooth_eps)
-    lr_scheduler = LearningRateScheduler(ModelHyperParams.d_model,
-                                         TrainTaskConfig.warmup_steps,
-                                         TrainTaskConfig.learning_rate)
-    optimizer = fluid.optimizer.Adam(
-        learning_rate=lr_scheduler.learning_rate,
-        beta1=TrainTaskConfig.beta1,
-        beta2=TrainTaskConfig.beta2,
-        epsilon=TrainTaskConfig.eps)
-    optimizer.minimize(sum_cost)
-    place = fluid.CUDAPlace(0) if TrainTaskConfig.use_gpu else fluid.CPUPlace()
-    exe = fluid.Executor(place)
         # Initialize the parameters.
         if TrainTaskConfig.ckpt_path:
             fluid.io.load_persistables(exe, TrainTaskConfig.ckpt_path)
-            lr_scheduler.current_steps = TrainTaskConfig.start_step
+            #lr_scheduler.current_steps = TrainTaskConfig.start_step
         else:
+            print "init fluid.framework.default_startup_program"
             exe.run(fluid.framework.default_startup_program())
         train_data = reader.DataReader(
@@ -282,7 +326,8 @@ def train(args):
             trg_vocab_fpath=args.trg_vocab_fpath,
             fpattern=args.train_file_pattern,
             use_token_batch=args.use_token_batch,
-            batch_size=args.batch_size * (1 if args.use_token_batch else dev_count),
+            batch_size=args.batch_size *
+            (1 if args.use_token_batch else dev_count),
             pool_size=args.pool_size,
             sort_type=args.sort_type,
             shuffle=args.shuffle,
@@ -290,12 +335,9 @@ def train(args):
             start_mark=args.special_token[0],
             end_mark=args.special_token[1],
             unk_mark=args.special_token[2],
-            max_length=ModelHyperParams.max_length,
             clip_last_batch=False)
-    train_data = read_multiple(
-        reader=train_data.batch_generator,
-        count=dev_count if args.use_token_batch else 1)
+        train_data = read_multiple(reader=train_data.batch_generator)
         build_strategy = fluid.BuildStrategy()
         # Since the token number differs among devices, customize gradient scale to
         # use token average cost among multi-devices. and the gradient scale is
@@ -304,15 +346,14 @@ def train(args):
         train_exe = fluid.ParallelExecutor(
             use_cuda=TrainTaskConfig.use_gpu,
             loss_name=sum_cost.name,
+            main_program=train_progm,
             build_strategy=build_strategy)
         def test_context():
             # Context to do validation.
-            test_program = fluid.default_main_program().clone(for_test=True)
-            test_exe = fluid.ParallelExecutor(
-                use_cuda=TrainTaskConfig.use_gpu,
-                main_program=test_program,
-                share_vars_from=train_exe)
+            test_program = train_progm.clone()
+            with fluid.program_guard(test_program):
+                test_program = fluid.io.get_inference_program([avg_cost])
             val_data = reader.DataReader(
                 src_vocab_fpath=args.src_vocab_fpath,
@@ -326,33 +367,34 @@
                 start_mark=args.special_token[0],
                 end_mark=args.special_token[1],
                 unk_mark=args.special_token[2],
-                max_length=ModelHyperParams.max_length,
                 clip_last_batch=False,
                 shuffle=False,
                 shuffle_batch=False)
+            test_exe = fluid.ParallelExecutor(
+                use_cuda=TrainTaskConfig.use_gpu,
+                main_program=test_program,
+                share_vars_from=train_exe)
             def test(exe=test_exe):
                 test_total_cost = 0
                 test_total_token = 0
-                test_data = read_multiple(
-                    reader=val_data.batch_generator,
-                    count=dev_count if args.use_token_batch else 1)
+                test_data = read_multiple(reader=val_data.batch_generator)
                 for batch_id, data in enumerate(test_data()):
                     feed_list = []
-                    for place_id, data_buffer in enumerate(
-                            split_data(
-                                data, num_part=dev_count)):
+                    for place_id, data_buffer in enumerate(split_data(data)):
                         data_input_dict, util_input_dict, _ = prepare_batch_input(
                             data_buffer, data_input_names, util_input_names,
                             ModelHyperParams.eos_idx, ModelHyperParams.eos_idx,
                             ModelHyperParams.n_head, ModelHyperParams.d_model)
                         feed_list.append(
-                            dict(data_input_dict.items() + util_input_dict.items()))
+                            dict(data_input_dict.items() +
+                                 util_input_dict.items()))
                     outs = exe.run(feed=feed_list,
                                    fetch_list=[sum_cost.name, token_num.name])
-                    sum_cost_val, token_num_val = np.array(outs[0]), np.array(outs[
-                        1])
+                    sum_cost_val, token_num_val = np.array(outs[0]), np.array(
+                        outs[1])
                     test_total_cost += sum_cost_val.sum()
                     test_total_token += token_num_val.sum()
                 test_avg_cost = test_total_cost / test_total_token
@@ -373,20 +415,22 @@
             for batch_id, data in enumerate(train_data()):
                 feed_list = []
                 total_num_token = 0
-                lr_rate = lr_scheduler.update_learning_rate()
-                for place_id, data_buffer in enumerate(
-                        split_data(
-                            data, num_part=dev_count)):
+                for place_id, data_buffer in enumerate(split_data(data)):
                     data_input_dict, util_input_dict, num_token = prepare_batch_input(
                         data_buffer, data_input_names, util_input_names,
                         ModelHyperParams.eos_idx, ModelHyperParams.eos_idx,
                         ModelHyperParams.n_head, ModelHyperParams.d_model)
                     total_num_token += num_token
-                    feed_list.append(
-                        dict(data_input_dict.items() + util_input_dict.items() +
-                             {lr_scheduler.learning_rate.name: lr_rate}.items()))
+                    feed_kv_pairs = data_input_dict.items(
+                    ) + util_input_dict.items()
+                    if args.local:
+                        lr_rate = lr_scheduler.update_learning_rate()
+                        feed_kv_pairs += {
+                            lr_scheduler.learning_rate.name: lr_rate
+                        }.items()
+                    feed_list.append(dict(feed_kv_pairs))
-                    if not init:  # init the position encoding table
+                    if not init:
                         for pos_enc_param_name in pos_enc_param_names:
                             pos_enc = position_encoding_init(
                                 ModelHyperParams.max_length + 1,
@@ -394,22 +438,25 @@
                             feed_list[place_id][pos_enc_param_name] = pos_enc
                 for feed_dict in feed_list:
                     feed_dict[sum_cost.name + "@GRAD"] = 1. / total_num_token
-                outs = train_exe.run(fetch_list=[sum_cost.name, token_num.name],
-                                     feed=feed_list)
+                outs = train_exe.run(
+                    fetch_list=[sum_cost.name, token_num.name], feed=feed_list)
+                train_exe.bcast_params()
-                sum_cost_val, token_num_val = np.array(outs[0]), np.array(outs[1])
+                sum_cost_val, token_num_val = np.array(outs[0]), np.array(outs[
+                    1])
                 total_sum_cost = sum_cost_val.sum(
                 )  # sum the cost from multi-devices
                 total_token_num = token_num_val.sum()
                 total_avg_cost = total_sum_cost / total_token_num
-                print("epoch: %d, batch: %d, sum loss: %f, avg loss: %f, ppl: %f" %
-                      (pass_id, batch_id, total_sum_cost, total_avg_cost,
+                print(
+                    "epoch: %d, batch: %d, sum loss: %f, avg loss: %f, ppl: %f"
+                    % (pass_id, batch_id, total_sum_cost, total_avg_cost,
                        np.exp([min(total_avg_cost, 100)])))
                 init = True
             # Validate and save the model for inference.
             print("epoch: %d, " % pass_id + (
                 "val avg loss: %f, val ppl: %f, " % test()
-                if args.val_file_pattern is not None else "") + "consumed %fs" % (
-                time.time() - pass_start_time))
+                if args.val_file_pattern is not None else "") + "consumed %fs" %
+                  (time.time() - pass_start_time))
             fluid.io.save_persistables(
                 exe,
                 os.path.join(TrainTaskConfig.ckpt_dir,
@@ -419,6 +466,48 @@
                              "pass_" + str(pass_id) + ".infer.model"),
                 data_input_names[:-2] + util_input_names, [predict], exe)
+    if args.local:
+        print("local start_up:")
+        train_loop(exe, fluid.default_main_program())
+    else:
+        port = os.getenv("PADDLE_PORT", "6174")
+        pserver_ips = os.getenv("PADDLE_PSERVERS")  # ip,ip...
+        eplist = []
+        for ip in pserver_ips.split(","):
+            eplist.append(':'.join([ip, port]))
+        pserver_endpoints = ",".join(eplist)  # ip:port,ip:port...
+        trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "0"))
+        current_endpoint = os.getenv("POD_IP") + ":" + port
+        trainer_id = int(os.getenv("PADDLE_TRAINER_ID"))
+        t = fluid.DistributeTranspiler()
+        t.transpile(trainer_id, pservers=pserver_endpoints, trainers=trainers)
+        if training_role == "PSERVER":
+            current_endpoint = os.getenv("POD_IP") + ":" + os.getenv(
+                "PADDLE_PORT")
+            if not current_endpoint:
+                print("need env SERVER_ENDPOINT")
+                exit(1)
+            pserver_prog = t.get_pserver_program(current_endpoint)
+            pserver_startup = t.get_startup_program(current_endpoint,
+                                                    pserver_prog)
+            print "pserver begin run"
+            with open('pserver_startup.desc', 'w') as f:
+                f.write(str(pserver_startup))
+            with open('pserver_prog.desc', 'w') as f:
+                f.write(str(pserver_prog))
+            exe.run(pserver_startup)
+            exe.run(pserver_prog)
+        elif training_role == "TRAINER":
+            trainer_prog = t.get_trainer_program()
+            with open('trainer_prog.desc', 'w') as f:
+                f.write(str(trainer_prog))
+            train_loop(exe, trainer_prog)
+        else:
+            print("environment var TRAINING_ROLE should be TRAINER or PSERVER")
 if __name__ == "__main__":
     args = parse_args()
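The non-local branch at the end of the diff assumes a set of environment variables has already been exported by the launcher. The sketch below shows one way those variables could be populated before calling train(parse_args()); the IP addresses and counts are placeholders, and launch_as is a hypothetical helper written for this note, not part of the commit.

import os

def launch_as(role, trainer_id=0):
    # Hypothetical helper: populate the env vars the new train() reads.
    os.environ["PADDLE_IS_LOCAL"] = "0"                  # force the distributed path
    os.environ["PADDLE_PSERVERS"] = "10.0.0.1,10.0.0.2"  # pserver IPs (placeholders)
    os.environ["PADDLE_PORT"] = "6174"
    os.environ["PADDLE_TRAINERS_NUM"] = "2"
    os.environ["POD_IP"] = "10.0.0.1"                    # this node's IP (placeholder)
    os.environ["TRAINING_ROLE"] = role                   # "PSERVER" or "TRAINER"
    os.environ["PADDLE_TRAINER_ID"] = str(trainer_id)
    # train(parse_args()) would then take the pserver or trainer branch.

launch_as("TRAINER", trainer_id=0)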