Commit a28f756b authored by: Q qiuxuezhong

Distributed training for transformer, local training as default

Parent 55849d4e
@@ -78,6 +78,12 @@ def parse_args():
help='See config.py for all options',
default=None,
nargs=argparse.REMAINDER)
parser.add_argument(
'--local',
type=ast.literal_eval,
default=True,
help='Whether to run as local mode.')
args = parser.parse_args()
# Append args related to dict
src_dict = reader.DataReader.load_dict(args.src_vocab_fpath)
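The new `--local` flag uses `ast.literal_eval` as its argparse type so that the literal strings `True`/`False` given on the command line become real Python booleans. A small standalone sketch (illustration only, not part of this commit) of that behaviour:

```python
# Illustration only: why type=ast.literal_eval is used for the new --local flag.
# With type=bool, any non-empty string (including "False") would parse as True.
import argparse
import ast

parser = argparse.ArgumentParser()
parser.add_argument(
    '--local',
    type=ast.literal_eval,
    default=True,
    help='Whether to run as local mode.')

print(parser.parse_args([]).local)                    # True  (the default)
print(parser.parse_args(['--local', 'False']).local)  # False (a real bool, not a string)
```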
@@ -204,49 +210,23 @@ def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
def train(args):
dev_count = fluid.core.get_cuda_device_count()
is_local = os.getenv("PADDLE_IS_LOCAL", "1")
if is_local == '0':
args.local = False
else:
args.local = True
print args
def read_multiple(reader,
count=dev_count if args.use_token_batch else 1,
clip_last=True):
"""
Stack data from reader for multi-devices.
"""
def __impl__():
res = []
for item in reader():
res.append(item)
if len(res) == count:
yield res
res = []
if len(res) == count:
yield res
elif not clip_last:
data = []
for item in res:
data += item
if len(data) > count:
inst_num_per_part = len(data) // count
yield [
data[inst_num_per_part * i:inst_num_per_part * (i + 1)]
for i in range(count)
]
return __impl__
def split_data(data, num_part=dev_count):
"""
Split data for each device.
"""
if len(data) == num_part:
return data
data = data[0]
inst_num_per_part = len(data) // num_part
return [
data[inst_num_per_part * i:inst_num_per_part * (i + 1)]
for i in range(num_part)
]
training_role = os.getenv("TRAINING_ROLE", "TRAINER")
if training_role == "PSERVER":
place = fluid.CPUPlace()
else:
place = fluid.CUDAPlace(
0) if TrainTaskConfig.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
dev_count = fluid.core.get_cuda_device_count()
sum_cost, avg_cost, predict, token_num = transformer(
ModelHyperParams.src_vocab_size, ModelHyperParams.trg_vocab_size,
@@ -266,152 +246,234 @@ def train(args):
epsilon=TrainTaskConfig.eps)
optimizer.minimize(sum_cost)
place = fluid.CUDAPlace(0) if TrainTaskConfig.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
# Initialize the parameters.
if TrainTaskConfig.ckpt_path:
fluid.io.load_persistables(exe, TrainTaskConfig.ckpt_path)
lr_scheduler.current_steps = TrainTaskConfig.start_step
else:
exe.run(fluid.framework.default_startup_program())
train_data = reader.DataReader(
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
fpattern=args.train_file_pattern,
use_token_batch=args.use_token_batch,
batch_size=args.batch_size * (1 if args.use_token_batch else dev_count),
pool_size=args.pool_size,
sort_type=args.sort_type,
shuffle=args.shuffle,
shuffle_batch=args.shuffle_batch,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2],
clip_last_batch=False)
train_data = read_multiple(reader=train_data.batch_generator)
build_strategy = fluid.BuildStrategy()
# Since the token number differs among devices, customize the gradient scale to
# use the token-averaged cost across devices; the gradient scale is
# `1 / token_number` for the average cost.
build_strategy.gradient_scale_strategy = fluid.BuildStrategy.GradientScaleStrategy.Customized
train_exe = fluid.ParallelExecutor(
use_cuda=TrainTaskConfig.use_gpu,
loss_name=sum_cost.name,
build_strategy=build_strategy)
def test_context():
# Context to do validation.
test_program = fluid.default_main_program().clone()
with fluid.program_guard(test_program):
test_program = fluid.io.get_inference_program([avg_cost])
val_data = reader.DataReader(
def train_loop(exe, train_progm):
def read_multiple(reader,
count=dev_count if args.use_token_batch else 1,
clip_last=True):
"""
Stack data from reader for multi-devices.
"""
def __impl__():
res = []
for item in reader():
res.append(item)
if len(res) == count:
yield res
res = []
if len(res) == count:
yield res
elif not clip_last:
data = []
for item in res:
data += item
if len(data) > count:
inst_num_per_part = len(data) // count
yield [
data[inst_num_per_part * i:inst_num_per_part * (i +
1)]
for i in range(count)
]
return __impl__
def split_data(data, num_part=dev_count):
"""
Split data for each device.
"""
if len(data) == num_part:
return data
data = data[0]
inst_num_per_part = len(data) // num_part
return [
data[inst_num_per_part * i:inst_num_per_part * (i + 1)]
for i in range(num_part)
]
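`read_multiple` and `split_data` above are the per-device batching helpers: the first stacks `count` consecutive batches from the underlying reader into one list (one sub-batch per device), the second splits such a stacked batch back into `num_part` equal slices. A toy sketch of that grouping behaviour (hypothetical integer "batches", unrelated to the real reader in this file):

```python
# Toy sketch of the grouping done by read_multiple; the data below is made up,
# only the stacking behaviour is meant to match.
def stack(reader, count):
    # Like read_multiple with clip_last=True: group `count` batches per yield.
    def __impl__():
        res = []
        for item in reader():
            res.append(item)
            if len(res) == count:
                yield res
                res = []
    return __impl__

def toy_reader():
    for batch in ([1, 2], [3, 4], [5, 6], [7, 8]):
        yield batch

for stacked in stack(toy_reader, count=2)():
    print(stacked)  # [[1, 2], [3, 4]] then [[5, 6], [7, 8]] -- one sub-list per device
```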
# Initialize the parameters.
if TrainTaskConfig.ckpt_path:
fluid.io.load_persistables(exe, TrainTaskConfig.ckpt_path)
lr_scheduler.current_steps = TrainTaskConfig.start_step
else:
exe.run(fluid.framework.default_startup_program())
train_data = reader.DataReader(
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
fpattern=args.val_file_pattern,
fpattern=args.train_file_pattern,
use_token_batch=args.use_token_batch,
batch_size=args.batch_size *
(1 if args.use_token_batch else dev_count),
pool_size=args.pool_size,
sort_type=args.sort_type,
shuffle=args.shuffle,
shuffle_batch=args.shuffle_batch,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2],
clip_last_batch=False,
shuffle=False,
shuffle_batch=False)
test_exe = fluid.ParallelExecutor(
clip_last_batch=False)
train_data = read_multiple(reader=train_data.batch_generator)
build_strategy = fluid.BuildStrategy()
# Since the token number differs among devices, customize the gradient scale to
# use the token-averaged cost across devices; the gradient scale is
# `1 / token_number` for the average cost.
build_strategy.gradient_scale_strategy = fluid.BuildStrategy.GradientScaleStrategy.Customized
train_exe = fluid.ParallelExecutor(
use_cuda=TrainTaskConfig.use_gpu,
main_program=test_program,
share_vars_from=train_exe)
def test(exe=test_exe):
test_total_cost = 0
test_total_token = 0
test_data = read_multiple(reader=val_data.batch_generator)
for batch_id, data in enumerate(test_data()):
loss_name=sum_cost.name,
main_program=train_progm,
build_strategy=build_strategy)
def test_context():
# Context to do validation.
test_program = train_progm.clone()
with fluid.program_guard(test_program):
test_program = fluid.io.get_inference_program([avg_cost])
val_data = reader.DataReader(
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
fpattern=args.val_file_pattern,
use_token_batch=args.use_token_batch,
batch_size=args.batch_size *
(1 if args.use_token_batch else dev_count),
pool_size=args.pool_size,
sort_type=args.sort_type,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2],
clip_last_batch=False,
shuffle=False,
shuffle_batch=False)
test_exe = fluid.ParallelExecutor(
use_cuda=TrainTaskConfig.use_gpu,
main_program=test_program,
share_vars_from=train_exe)
def test(exe=test_exe):
test_total_cost = 0
test_total_token = 0
test_data = read_multiple(reader=val_data.batch_generator)
for batch_id, data in enumerate(test_data()):
feed_list = []
for place_id, data_buffer in enumerate(split_data(data)):
data_input_dict, util_input_dict, _ = prepare_batch_input(
data_buffer, data_input_names, util_input_names,
ModelHyperParams.eos_idx, ModelHyperParams.eos_idx,
ModelHyperParams.n_head, ModelHyperParams.d_model)
feed_list.append(
dict(data_input_dict.items() +
util_input_dict.items()))
outs = exe.run(feed=feed_list,
fetch_list=[sum_cost.name, token_num.name])
sum_cost_val, token_num_val = np.array(outs[0]), np.array(
outs[1])
test_total_cost += sum_cost_val.sum()
test_total_token += token_num_val.sum()
test_avg_cost = test_total_cost / test_total_token
test_ppl = np.exp([min(test_avg_cost, 100)])
return test_avg_cost, test_ppl
return test
if args.val_file_pattern is not None:
test = test_context()
data_input_names = encoder_data_input_fields + decoder_data_input_fields[:
-1] + label_data_input_fields
util_input_names = encoder_util_input_fields + decoder_util_input_fields
init = False
for pass_id in xrange(TrainTaskConfig.pass_num):
pass_start_time = time.time()
for batch_id, data in enumerate(train_data()):
feed_list = []
total_num_token = 0
lr_rate = lr_scheduler.update_learning_rate()
for place_id, data_buffer in enumerate(split_data(data)):
data_input_dict, util_input_dict, _ = prepare_batch_input(
data_input_dict, util_input_dict, num_token = prepare_batch_input(
data_buffer, data_input_names, util_input_names,
ModelHyperParams.eos_idx, ModelHyperParams.eos_idx,
ModelHyperParams.n_head, ModelHyperParams.d_model)
total_num_token += num_token
feed_list.append(
dict(data_input_dict.items() + util_input_dict.items()))
outs = exe.run(feed=feed_list,
fetch_list=[sum_cost.name, token_num.name])
dict(data_input_dict.items() + util_input_dict.items(
) + {lr_scheduler.learning_rate.name: lr_rate}.items()))
if not init:
for pos_enc_param_name in pos_enc_param_names:
pos_enc = position_encoding_init(
ModelHyperParams.max_length + 1,
ModelHyperParams.d_model)
feed_list[place_id][pos_enc_param_name] = pos_enc
for feed_dict in feed_list:
feed_dict[
sum_cost.name +
"@GRAD"] = 1. / total_num_token if TrainTaskConfig.use_avg_cost else np.asarray(
[1.], dtype="float32")
outs = train_exe.run(
fetch_list=[sum_cost.name, token_num.name], feed=feed_list)
sum_cost_val, token_num_val = np.array(outs[0]), np.array(outs[
1])
test_total_cost += sum_cost_val.sum()
test_total_token += token_num_val.sum()
test_avg_cost = test_total_cost / test_total_token
test_ppl = np.exp([min(test_avg_cost, 100)])
return test_avg_cost, test_ppl
return test
if args.val_file_pattern is not None:
test = test_context()
data_input_names = encoder_data_input_fields + decoder_data_input_fields[:
-1] + label_data_input_fields
util_input_names = encoder_util_input_fields + decoder_util_input_fields
init = False
for pass_id in xrange(TrainTaskConfig.pass_num):
pass_start_time = time.time()
for batch_id, data in enumerate(train_data()):
feed_list = []
total_num_token = 0
lr_rate = lr_scheduler.update_learning_rate()
for place_id, data_buffer in enumerate(split_data(data)):
data_input_dict, util_input_dict, num_token = prepare_batch_input(
data_buffer, data_input_names, util_input_names,
ModelHyperParams.eos_idx, ModelHyperParams.eos_idx,
ModelHyperParams.n_head, ModelHyperParams.d_model)
total_num_token += num_token
feed_list.append(
dict(data_input_dict.items() + util_input_dict.items() +
{lr_scheduler.learning_rate.name: lr_rate}.items()))
if not init:
for pos_enc_param_name in pos_enc_param_names:
pos_enc = position_encoding_init(
ModelHyperParams.max_length + 1,
ModelHyperParams.d_model)
feed_list[place_id][pos_enc_param_name] = pos_enc
for feed_dict in feed_list:
feed_dict[
sum_cost.name +
"@GRAD"] = 1. / total_num_token if TrainTaskConfig.use_avg_cost else np.asarray(
[1.], dtype="float32")
outs = train_exe.run(fetch_list=[sum_cost.name, token_num.name],
feed=feed_list)
sum_cost_val, token_num_val = np.array(outs[0]), np.array(outs[1])
total_sum_cost = sum_cost_val.sum(
) # sum the cost from multi-devices
total_token_num = token_num_val.sum()
total_avg_cost = total_sum_cost / total_token_num
print("epoch: %d, batch: %d, sum loss: %f, avg loss: %f, ppl: %f" %
(pass_id, batch_id, total_sum_cost, total_avg_cost,
np.exp([min(total_avg_cost, 100)])))
init = True
# Validate and save the model for inference.
print("epoch: %d, " % pass_id + (
"val avg loss: %f, val ppl: %f, " % test()
if args.val_file_pattern is not None else "") + "consumed %fs" % (
time.time() - pass_start_time))
fluid.io.save_persistables(
exe,
os.path.join(TrainTaskConfig.ckpt_dir,
"pass_" + str(pass_id) + ".checkpoint"))
fluid.io.save_inference_model(
os.path.join(TrainTaskConfig.model_dir,
"pass_" + str(pass_id) + ".infer.model"),
data_input_names[:-2] + util_input_names, [predict], exe)
total_sum_cost = sum_cost_val.sum(
) # sum the cost from multi-devices
total_token_num = token_num_val.sum()
total_avg_cost = total_sum_cost / total_token_num
print(
"epoch: %d, batch: %d, sum loss: %f, avg loss: %f, ppl: %f"
% (pass_id, batch_id, total_sum_cost, total_avg_cost,
np.exp([min(total_avg_cost, 100)])))
init = True
# Validate and save the model for inference.
print("epoch: %d, " % pass_id + (
"val avg loss: %f, val ppl: %f, " % test()
if args.val_file_pattern is not None else "") + "consumed %fs" %
(time.time() - pass_start_time))
fluid.io.save_persistables(
exe,
os.path.join(TrainTaskConfig.ckpt_dir,
"pass_" + str(pass_id) + ".checkpoint"))
fluid.io.save_inference_model(
os.path.join(TrainTaskConfig.model_dir,
"pass_" + str(pass_id) + ".infer.model"),
data_input_names[:-2] + util_input_names, [predict], exe)
if args.local:
print("local start_up:")
train_loop(exe, fluid.default_main_program())
else:
port = os.getenv("PADDLE_PORT", "6174")
pserver_ips = os.getenv("PADDLE_PSERVERS") # ip,ip...
eplist = []
for ip in pserver_ips.split(","):
eplist.append(':'.join([ip, port]))
pserver_endpoints = ",".join(eplist) # ip:port,ip:port...
trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "0"))
current_endpoint = os.getenv("POD_IP") + ":" + port
trainer_id = int(os.getenv("PADDLE_TRAINER_ID"))
t = fluid.DistributeTranspiler()
t.transpile(trainer_id, pservers=pserver_endpoints, trainers=trainers)
if training_role == "PSERVER":
current_endpoint = os.getenv("POD_IP") + ":" + os.getenv(
"PADDLE_PORT")
if not current_endpoint:
print("need env SERVER_ENDPOINT")
exit(1)
pserver_prog = t.get_pserver_program(current_endpoint)
pserver_startup = t.get_startup_program(current_endpoint,
pserver_prog)
print "pserver begin run"
exe.run(pserver_startup)
exe.run(pserver_prog)
elif training_role == "TRAINER":
trainer_prog = t.get_trainer_program()
train_loop(exe, trainer_prog)
else:
print("environment var TRAINING_ROLE should be TRAINER or PSERVER")
if __name__ == "__main__":
......
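For completeness, a hedged sketch of how the environment variables read by the distributed branch (PADDLE_IS_LOCAL, TRAINING_ROLE, PADDLE_PSERVERS, PADDLE_PORT, POD_IP, PADDLE_TRAINERS_NUM, PADDLE_TRAINER_ID) might be set to launch one parameter server and one trainer on a single host. The IPs, port, and process layout here are assumptions, not part of the commit; real cluster launchers set these variables themselves.

```python
# Illustrative launcher only: the env var names come from train.py above,
# but the single-host layout (one pserver, one trainer) is an assumption.
import os
import subprocess

COMMON = {
    "PADDLE_IS_LOCAL": "0",
    "PADDLE_PORT": "6174",
    "PADDLE_PSERVERS": "127.0.0.1",   # comma-separated pserver IPs
    "PADDLE_TRAINERS_NUM": "1",
    "POD_IP": "127.0.0.1",
}

def launch(role, trainer_id="0"):
    env = dict(os.environ, **COMMON)
    env["TRAINING_ROLE"] = role        # "PSERVER" or "TRAINER"
    env["PADDLE_TRAINER_ID"] = trainer_id
    return subprocess.Popen(
        ["python", "train.py", "--local", "False"], env=env)

if __name__ == "__main__":
    ps = launch("PSERVER")
    trainer = launch("TRAINER")
    trainer.wait()
    ps.terminate()
```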