Commit 3204a7f7 authored by gongweibao

fix

Parent 579be0c1
@@ -12,6 +12,8 @@ import reader
from config import *
from model import transformer, position_encoding_init
import logging
import sys


def parse_args():
    parser = argparse.ArgumentParser("Training for Transformer.")
@@ -98,6 +100,11 @@ def parse_args():
        default='GPU',
        choices=['CPU', 'GPU'],
        help="The device type.")
    parser.add_argument(
        '--update_method',
        choices=("pserver", "nccl2"),
        default="pserver",
        help='Update method.')
    parser.add_argument(
        '--sync', type=ast.literal_eval, default=True, help="sync mode.")
    parser.add_argument(
@@ -116,6 +123,12 @@ def parse_args():
        type=ast.literal_eval,
        default=True,
        help="The flag indicating whether to use py_reader.")
    parser.add_argument(
        "--fetch_steps",
        type=int,
        default=100,
        help="The frequency, in steps, of fetching and logging the training outputs.")
    args = parser.parse_args()
    # Append args related to dict
@@ -131,6 +144,26 @@ def parse_args():
        [TrainTaskConfig, ModelHyperParams])
    return args
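A note on the `type=ast.literal_eval` arguments above: unlike `bool`, which treats any non-empty string as true, `ast.literal_eval` actually parses the literal, so `--sync False` produces a real boolean. A minimal sketch:

import ast

print(bool("False"))              # True  -- any non-empty string is truthy
print(ast.literal_eval("False"))  # False -- the literal is parsed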
def append_nccl2_prepare(trainer_id, worker_endpoints, current_endpoint):
    assert (trainer_id >= 0 and len(worker_endpoints) > 1 and
            current_endpoint in worker_endpoints)
    # Every trainer knows the full endpoint list; remove itself to get
    # the list of peers.
    eps = copy.deepcopy(worker_endpoints)
    eps.remove(current_endpoint)
    # A raw variable in the startup program to hold the NCCL unique ID.
    nccl_id_var = fluid.default_startup_program().global_block().create_var(
        name="NCCLID",
        persistable=True,
        type=fluid.core.VarDesc.VarType.RAW)
    fluid.default_startup_program().global_block().append_op(
        type="gen_nccl_id",
        inputs={},
        outputs={"NCCLID": nccl_id_var},
        attrs={
            "endpoint": current_endpoint,
            "endpoint_list": eps,
            "trainer_id": trainer_id
        })
    return nccl_id_var
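The `gen_nccl_id` op has trainer 0 generate the NCCL unique ID and distribute it to the peer endpoints, leaving it in the raw `NCCLID` variable; the helper also assumes `copy` is imported elsewhere in the file (the import is not part of this diff). A minimal sketch of a call, with illustrative endpoints (train() below derives the same arguments from environment variables):

worker_endpoints = ["10.0.0.1:6174", "10.0.0.2:6174", "10.0.0.3:6174"]
nccl_id_var = append_nccl2_prepare(trainer_id=0,
                                   worker_endpoints=worker_endpoints,
                                   current_endpoint="10.0.0.1:6174")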


def pad_batch_data(insts,
                   pad_idx,
@@ -409,14 +442,15 @@ def test_context(exe, train_exe, dev_count):
def train_loop(exe, train_prog, startup_prog, dev_count, sum_cost, avg_cost,
               token_num, predict, pyreader, nccl2_num_trainers=1,
               nccl2_trainer_id=0):
    # Initialize the parameters.
    if TrainTaskConfig.ckpt_path:
        fluid.io.load_persistables(exe, TrainTaskConfig.ckpt_path)
    else:
        logging.info("init fluid.framework.default_startup_program")
        exe.run(startup_prog)

    logging.info("begin reader")
    train_data = prepare_data_generator(
        args, is_test=False, count=dev_count, pyreader=pyreader)
@@ -429,12 +463,19 @@ def train_loop(exe, train_prog, startup_prog, dev_count, sum_cost, avg_cost,
    # Use the token-averaged cost across devices; the gradient scale is
    # `1 / token_number` for the averaged cost.
    build_strategy.gradient_scale_strategy = \
        fluid.BuildStrategy.GradientScaleStrategy.Customized

    logging.info("begin build executor")
    exec_strategy = fluid.ExecutionStrategy()
    if args.update_method == "nccl2":
        exec_strategy.num_threads = 1

    train_exe = fluid.ParallelExecutor(
        use_cuda=TrainTaskConfig.use_gpu,
        loss_name=avg_cost.name,
        main_program=train_prog,
        build_strategy=build_strategy,
        exec_strategy=exec_strategy,
        num_trainers=nccl2_num_trainers,
        trainer_id=nccl2_trainer_id)

    if args.val_file_pattern is not None:
        test = test_context(exe, train_exe, dev_count)
@@ -448,6 +489,8 @@ def train_loop(exe, train_prog, startup_prog, dev_count, sum_cost, avg_cost,
    step_idx = 0
    init_flag = True

    logging.info("begin train")
    for pass_id in six.moves.xrange(TrainTaskConfig.pass_num):
        pass_start_time = time.time()
@@ -458,26 +501,30 @@ def train_loop(exe, train_prog, startup_prog, dev_count, sum_cost, avg_cost,
        data_generator = train_data()

        batch_id = 0
        avg_batch_time = time.time()
        while True:
            try:
                feed_dict_list = prepare_feed_dict_list(data_generator,
                                                        init_flag, dev_count)
                # Fetching forces a device-to-host sync, so only fetch the
                # outputs every `fetch_steps` batches.
                outs = train_exe.run(
                    fetch_list=[sum_cost.name, token_num.name]
                    if batch_id % args.fetch_steps == 0 else [],
                    feed=feed_dict_list)

                if batch_id % args.fetch_steps == 0 and batch_id > 0:
                    sum_cost_val, token_num_val = np.array(outs[0]), np.array(
                        outs[1])
                    # Sum the cost from multi-devices.
                    total_sum_cost = sum_cost_val.sum()
                    total_token_num = token_num_val.sum()
                    total_avg_cost = total_sum_cost / total_token_num

                    logging.info(
                        "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                        "normalized loss: %f, ppl: %f, speed: %.2f step/s" %
                        (step_idx, pass_id, batch_id, total_avg_cost,
                         total_avg_cost - loss_normalizer,
                         np.exp([min(total_avg_cost, 100)]),
                         args.fetch_steps / (time.time() - avg_batch_time)))

                if step_idx % int(TrainTaskConfig.save_freq) == \
                        TrainTaskConfig.save_freq - 1:
@@ -490,6 +537,8 @@ def train_loop(exe, train_prog, startup_prog, dev_count, sum_cost, avg_cost,
                        os.path.join(TrainTaskConfig.model_dir,
                                     "iter_" + str(step_idx) + ".infer.model"),
                        train_prog)

                if batch_id % args.fetch_steps == 0 and batch_id > 0:
                    avg_batch_time = time.time()
                init_flag = False
                batch_id += 1
                step_idx += 1
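The conditional fetch above is the point of this hunk: fetching variables forces a device-to-host copy on every step, so the loop pays that cost only every `fetch_steps` batches and uses the elapsed interval to report throughput. A self-contained sketch of the pattern, with a stub standing in for train_exe.run (all names here are illustrative):

import time

fetch_steps = 100

def fake_run(fetch_list):
    # Stand-in for train_exe.run: a non-empty fetch_list would force a
    # device sync, so most iterations pass an empty one.
    return [1.0 for _ in fetch_list]

interval_start = time.time()
for batch_id in range(1000):
    outs = fake_run(fetch_list=["sum_cost", "token_num"]
                    if batch_id % fetch_steps == 0 else [])
    if batch_id % fetch_steps == 0 and batch_id > 0:
        speed = fetch_steps / (time.time() - interval_start)
        print("batch %d: %.2f step/s" % (batch_id, speed))
        interval_start = time.time()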
@@ -503,13 +552,13 @@ def train_loop(exe, train_prog, startup_prog, dev_count, sum_cost, avg_cost,
        # Validate and save the persistable.
        if args.val_file_pattern is not None:
            val_avg_cost, val_ppl = test()
            logging.info(
                "epoch: %d, val avg loss: %f, val normalized loss: %f, "
                "val ppl: %f, consumed %fs" %
                (pass_id, val_avg_cost, val_avg_cost - loss_normalizer,
                 val_ppl, time_consumed))
        else:
            logging.info("epoch: %d, consumed %fs" % (pass_id, time_consumed))

        fluid.io.save_persistables(
            exe,
            os.path.join(TrainTaskConfig.ckpt_dir,
@@ -527,7 +576,7 @@ def train(args):
    is_local = os.getenv("PADDLE_IS_LOCAL", "1")
    if is_local == '0':
        args.local = False

    logging.info(args)

    if args.device == 'CPU':
        TrainTaskConfig.use_gpu = False
@@ -592,6 +641,26 @@ def train(args):
        train_loop(exe, train_prog, startup_prog, dev_count, sum_cost,
                   avg_cost, token_num, predict, pyreader)
    else:
        if args.update_method == "nccl2":
            trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
            port = os.getenv("PADDLE_PORT")
            worker_ips = os.getenv("PADDLE_TRAINERS")
            worker_endpoints = []
            for ip in worker_ips.split(","):
                worker_endpoints.append(':'.join([ip, port]))
            trainers_num = len(worker_endpoints)
            current_endpoint = os.getenv("POD_IP") + ":" + port
            if trainer_id == 0:
                logging.info("trainer_id == 0, sleep 60s")
                time.sleep(60)
            logging.info("trainers_num: %s" % trainers_num)
            logging.info("worker_endpoints: %s" % worker_endpoints)
            logging.info("current_endpoint: %s" % current_endpoint)
            append_nccl2_prepare(trainer_id, worker_endpoints,
                                 current_endpoint)
            train_loop(exe, fluid.default_main_program(), startup_prog,
                       dev_count, sum_cost, avg_cost, token_num, predict,
                       pyreader, trainers_num, trainer_id)
            return
        port = os.getenv("PADDLE_PORT", "6174")
        pserver_ips = os.getenv("PADDLE_PSERVERS")  # ip,ip...
        eplist = []
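The nccl2 branch above is driven entirely by environment variables. A sketch of what the first node of a hypothetical two-trainer job would export (values illustrative):

import os

os.environ["PADDLE_IS_LOCAL"] = "0"                  # take the distributed path
os.environ["PADDLE_TRAINER_ID"] = "0"                # this trainer's rank
os.environ["PADDLE_PORT"] = "6174"
os.environ["PADDLE_TRAINERS"] = "10.0.0.1,10.0.0.2"  # all trainer ips
os.environ["POD_IP"] = "10.0.0.1"                    # this node's ip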
@@ -610,6 +679,7 @@ def train(args):
            startup_program=startup_prog)

        if training_role == "PSERVER":
            logging.info("distributed: pserver started")
            current_endpoint = os.getenv("POD_IP") + ":" + os.getenv(
                "PADDLE_PORT")
            if not current_endpoint:
@@ -619,23 +689,21 @@ def train(args):
            pserver_startup = t.get_startup_program(current_endpoint,
                                                    pserver_prog)

            logging.info("pserver begin run")
            with open('pserver_startup.desc', 'w') as f:
                f.write(str(pserver_startup))
            with open('pserver_prog.desc', 'w') as f:
                f.write(str(pserver_prog))
            exe.run(pserver_startup)
            exe.run(pserver_prog)
        elif training_role == "TRAINER":
            logging.info("distributed: trainer started")
            trainer_prog = t.get_trainer_program()

            with open('trainer_prog.desc', 'w') as f:
                f.write(str(trainer_prog))

            train_loop(exe, train_prog, startup_prog, dev_count, sum_cost,
                       avg_cost, token_num, predict, pyreader)
        else:
            logging.critical(
                "environment var TRAINER_ROLE should be TRAINER or PSERVER")
            exit(1)
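For the pserver path, the transpiler splits the program into pserver and trainer sides based on the role. The code that fills `training_role` is in a hunk elided from this diff; the error message above suggests a TRAINER/PSERVER-valued variable. A sketch of the environment a hypothetical parameter-server node might export (anything not visible in the diff is an assumption):

import os

os.environ["PADDLE_IS_LOCAL"] = "0"
os.environ["PADDLE_PORT"] = "6174"
os.environ["PADDLE_PSERVERS"] = "10.0.0.3,10.0.0.4"  # ip,ip...
os.environ["POD_IP"] = "10.0.0.3"                    # this pserver's ip
# The role variable itself is read in elided code; the error message
# above refers to it as TRAINER_ROLE.
os.environ["TRAINER_ROLE"] = "PSERVER"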


if __name__ == "__main__":
    LOG_FORMAT = "[%(asctime)s %(levelname)s %(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(
        stream=sys.stdout, level=logging.DEBUG, format=LOG_FORMAT)

    args = parse_args()
    train(args)
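The `basicConfig` call routes every module-level `logging` call in the file to stdout, which is why the commit replaces the bare prints above. A standalone demonstration of the format (the timestamp in the comment is illustrative):

import logging
import sys

LOG_FORMAT = "[%(asctime)s %(levelname)s %(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG, format=LOG_FORMAT)
logging.info("begin train")
# prints something like: [2018-09-10 12:00:00,000 INFO demo.py:6] begin train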