“069ff14756f798f52f4af746588014b93d01f839”上不存在“...paddle/fluid/tests/unittests/test_dist_mnist_fleetapi.py”
未验证 提交 77f6c273 编写于 作者: G gongweibao 提交者: GitHub

Cleanup transformer train code! (#1392)

上级 514fab8b
#!/bin/bash
set -x
unset http_proxy
unset https_proxy
#pserver
export TRAINING_ROLE=PSERVER
export PADDLE_PORT=30134
export PADDLE_PSERVERS=127.0.0.1
export PADDLE_IS_LOCAL=0
export PADDLE_INIT_TRAINER_COUNT=1
export POD_IP=127.0.0.1
export PADDLE_TRAINER_ID=0
export PADDLE_TRAINERS_NUM=1
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib64/:/usr/local/lib/:/workspace/brpc
export PYTHONPATH=$PYTHONPATH:/paddle/build/build_reader_RelWithDebInfo_gpu/python
#GLOG_v=7 GLOG_logtostderr=1
CUDA_VISIBLE_DEVICES=4,5,6,7 python -u train.py \
--src_vocab_fpath 'cluster_test_data_en_fr/thirdparty/vocab.wordpiece.en-fr' \
--trg_vocab_fpath 'cluster_test_data_en_fr/thirdparty/vocab.wordpiece.en-fr' \
--special_token '<s>' '<e>' '<unk>' \
--token_delimiter '\x01' \
--train_file_pattern 'cluster_test_data_en_fr/train/train.wordpiece.en-fr.0' \
--val_file_pattern 'cluster_test_data_en_fr/thirdparty/newstest2014.wordpiece.en-fr' \
--use_token_batch True \
--batch_size 3200 \
--sort_type pool \
--pool_size 200000 \
--local False > pserver.log 2>&1 &
pserver_pid=$(echo $!)
echo $pserver_pid
sleep 30s
#trainer
export TRAINING_ROLE=TRAINER
export PADDLE_PORT=30134
export PADDLE_PSERVERS=127.0.0.1
export PADDLE_IS_LOCAL=0
export PADDLE_INIT_TRAINER_COUNT=1
export POD_IP=127.0.0.1
export PADDLE_TRAINER_ID=0
export PADDLE_TRAINERS_NUM=1
CUDA_VISIBLE_DEVICES=4,5,6,7 python -u train.py \
--src_vocab_fpath 'cluster_test_data_en_fr/thirdparty/vocab.wordpiece.en-fr' \
--trg_vocab_fpath 'cluster_test_data_en_fr/thirdparty/vocab.wordpiece.en-fr' \
--special_token '<s>' '<e>' '<unk>' \
--token_delimiter '\x01' \
--train_file_pattern 'cluster_test_data_en_fr/train/train.wordpiece.en-fr.0' \
--val_file_pattern 'cluster_test_data_en_fr/thirdparty/newstest2014.wordpiece.en-fr' \
--use_token_batch True \
--batch_size 3200 \
--sort_type pool \
--pool_size 200000 \
--local False > trainer.log 2>&1 &
#sleep 80
#kill -9 $pserver_pid
...@@ -643,7 +643,7 @@ def train(args): ...@@ -643,7 +643,7 @@ def train(args):
if args.sync: if args.sync:
lr_decay = fluid.layers.learning_rate_scheduler.noam_decay( lr_decay = fluid.layers.learning_rate_scheduler.noam_decay(
ModelHyperParams.d_model, TrainTaskConfig.warmup_steps) ModelHyperParams.d_model, TrainTaskConfig.warmup_steps)
print("before adam") logging.info("before adam")
with fluid.default_main_program()._lr_schedule_guard(): with fluid.default_main_program()._lr_schedule_guard():
learning_rate = lr_decay * TrainTaskConfig.learning_rate learning_rate = lr_decay * TrainTaskConfig.learning_rate
...@@ -661,7 +661,7 @@ def train(args): ...@@ -661,7 +661,7 @@ def train(args):
fluid.memory_optimize(train_prog) fluid.memory_optimize(train_prog)
if args.local: if args.local:
print("local start_up:") logging.info("local start_up:")
train_loop(exe, train_prog, startup_prog, dev_count, sum_cost, avg_cost, train_loop(exe, train_prog, startup_prog, dev_count, sum_cost, avg_cost,
token_num, predict, pyreader) token_num, predict, pyreader)
else: else:
...@@ -677,9 +677,9 @@ def train(args): ...@@ -677,9 +677,9 @@ def train(args):
if trainer_id == 0: if trainer_id == 0:
logging.info("train_id == 0, sleep 60s") logging.info("train_id == 0, sleep 60s")
time.sleep(60) time.sleep(60)
print("trainers_num:", trainers_num) logging.info("trainers_num:{}".format(trainers_num))
print("worker_endpoints:", worker_endpoints) logging.info("worker_endpoints:{}".format(worker_endpoints))
print("current_endpoint:", current_endpoint) logging.info("current_endpoint:{}".format(current_endpoint))
append_nccl2_prepare(trainer_id, worker_endpoints, current_endpoint) append_nccl2_prepare(trainer_id, worker_endpoints, current_endpoint)
train_loop(exe, train_loop(exe,
fluid.default_main_program(), dev_count, sum_cost, fluid.default_main_program(), dev_count, sum_cost,
...@@ -696,11 +696,11 @@ def train(args): ...@@ -696,11 +696,11 @@ def train(args):
current_endpoint = os.getenv("POD_IP") + ":" + port current_endpoint = os.getenv("POD_IP") + ":" + port
trainer_id = int(os.getenv("PADDLE_TRAINER_ID")) trainer_id = int(os.getenv("PADDLE_TRAINER_ID"))
print("pserver_endpoints", pserver_endpoints) logging.info("pserver_endpoints:{}".format(pserver_endpoints))
print("current_endpoint", current_endpoint) logging.info("current_endpoint:{}".format(current_endpoint))
print("trainer_id", trainer_id) logging.info("trainer_id:{}".format(trainer_id))
print("pserver_ips", pserver_ips) logging.info("pserver_ips:{}".format(pserver_ips))
print("port", port) logging.info("port:{}".format(port))
t = fluid.DistributeTranspiler() t = fluid.DistributeTranspiler()
t.transpile( t.transpile(
...@@ -715,30 +715,17 @@ def train(args): ...@@ -715,30 +715,17 @@ def train(args):
current_endpoint = os.getenv("POD_IP") + ":" + os.getenv( current_endpoint = os.getenv("POD_IP") + ":" + os.getenv(
"PADDLE_PORT") "PADDLE_PORT")
if not current_endpoint: if not current_endpoint:
print("need env SERVER_ENDPOINT") logging.critical("need env SERVER_ENDPOINT")
exit(1) exit(1)
pserver_prog = t.get_pserver_program(current_endpoint) pserver_prog = t.get_pserver_program(current_endpoint)
pserver_startup = t.get_startup_program(current_endpoint, pserver_startup = t.get_startup_program(current_endpoint,
pserver_prog) pserver_prog)
print("pserver start:")
program_to_code(pserver_startup)
print("pserver train:")
program_to_code(pserver_prog)
#sys.exit(0)
exe.run(pserver_startup) exe.run(pserver_startup)
exe.run(pserver_prog) exe.run(pserver_prog)
elif training_role == "TRAINER": elif training_role == "TRAINER":
logging.info("distributed: trainer started") logging.info("distributed: trainer started")
trainer_prog = t.get_trainer_program() trainer_prog = t.get_trainer_program()
'''
print("trainer start:")
program_to_code(pserver_startup)
print("trainer train:")
program_to_code(trainer_prog)
sys.exit(0)
'''
train_loop(exe, train_prog, startup_prog, dev_count, sum_cost, train_loop(exe, train_prog, startup_prog, dev_count, sum_cost,
avg_cost, token_num, predict, pyreader) avg_cost, token_num, predict, pyreader)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册