Commit 93ca734c authored by T tangwei12

rewrite cluster_train.sh

Parent fc212827
#!/bin/bash
# start pserver0
python train.py \
--train_data_path /paddle/data/train.txt \
--is_local 0 \
--role pserver \
--endpoints 127.0.0.1:6000,127.0.0.1:6001 \
--current_endpoint 127.0.0.1:6000 \
--trainers 2 \
> pserver0.log 2>&1 &
# start pserver1
python train.py \
--train_data_path /paddle/data/train.txt \
--is_local 0 \
--role pserver \
--endpoints 127.0.0.1:6000,127.0.0.1:6001 \
--current_endpoint 127.0.0.1:6001 \
--trainers 2 \
> pserver1.log 2>&1 &
# start trainer0
python train.py \
--train_data_path /paddle/data/train.txt \
--is_local 0 \
--role trainer \
--endpoints 127.0.0.1:6000,127.0.0.1:6001 \
--trainers 2 \
--trainer_id 0 \
> trainer0.log 2>&1 &
# start trainer1
python train.py \
--train_data_path /paddle/data/train.txt \
--is_local 0 \
--role trainer \
--endpoints 127.0.0.1:6000,127.0.0.1:6001 \
--trainers 2 \
--trainer_id 1 \
> trainer1.log 2>&1 &
echo "WARNING: This script only for run PaddlePaddle Fluid on one node..."
echo ""
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib
export PADDLE_PSERVER_PORTS=36001,36002
export PADDLE_PSERVER_PORT_ARRAY=(36001 36002)
export PADDLE_PSERVERS=2
export PADDLE_IP=127.0.0.1
export PADDLE_TRAINERS=2
export CPU_NUM=2
export NUM_THREADS=2
export PADDLE_SYNC_MODE=TRUE
export PADDLE_IS_LOCAL=0
export FLAGS_rpc_deadline=3000000
export GLOG_logtostderr=1
export TRAIN_DATA=data/enwik8
export DICT_PATH=data/enwik8_dict
export IS_SPARSE="--is_sparse"
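# editor's note (not part of the committed script): how these variables are consumed by the patched train.py shown below
#   PADDLE_TRAINING_ROLE                -> selects the PSERVER / TRAINER branch
#   PADDLE_PSERVER_PORTS + PADDLE_IP    -> joined into the pserver endpoint list
#   CUR_PORT + PADDLE_IP                -> the current pserver endpoint
#   PADDLE_TRAINER_ID, PADDLE_TRAINERS  -> arguments to DistributeTranspiler.transpile()
#   PADDLE_IS_LOCAL                     -> 1 = local training, 0 = distributed training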
echo "Start PSERVER ..."
for((i=0;i<$PADDLE_PSERVERS;i++))
do
cur_port=${PADDLE_PSERVER_PORT_ARRAY[$i]}
echo "PADDLE WILL START PSERVER "$cur_port
GLOG_v=0 PADDLE_TRAINING_ROLE=PSERVER CUR_PORT=$cur_port PADDLE_TRAINER_ID=$i python -u train.py $IS_SPARSE &> pserver.$i.log &
done
echo "Start TRAINER ..."
for((i=0;i<$PADDLE_TRAINERS;i++))
do
echo "PADDLE WILL START Trainer "$i
GLOG_v=0 PADDLE_TRAINER_ID=$i PADDLE_TRAINING_ROLE=TRAINER python -u train.py $IS_SPARSE --train_data_path $TRAIN_DATA --dict_path $DICT_PATH &> trainer.$i.log &
done
\ No newline at end of file
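For reference, a minimal usage sketch for the rewritten launcher above (an editor's addition, not part of the commit; the cleanup command assumes no other train.py jobs are running on this machine):

# launch two pservers and two trainers on the local machine
bash cluster_train.sh
# follow progress; log names match the redirects in the loops above
tail -f pserver.0.log trainer.0.log
# stop all background train.py processes when finished
pkill -f "python -u train.py"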
@@ -22,7 +22,7 @@ logger.setLevel(logging.INFO)
def parse_args():
parser = argparse.ArgumentParser(description="PaddlePaddle CTR example")
parser = argparse.ArgumentParser(description="PaddlePaddle Word2vec example")
parser.add_argument(
'--train_data_path',
type=str,
@@ -85,38 +85,6 @@ def parse_args():
default=False,
help='embedding and nce will use sparse or not, (default: False)')
parser.add_argument(
'--is_local',
type=int,
default=1,
help='Local train or distributed train (default: 1)')
# the following arguments are used for distributed training; if is_local == false, you should set them
parser.add_argument(
'--role',
type=str,
default='pserver', # trainer or pserver
help='The training role (trainer|pserver) (default: pserver)')
parser.add_argument(
'--endpoints',
type=str,
default='127.0.0.1:6000',
help='The pserver endpoints, like: 127.0.0.1:6000,127.0.0.1:6001')
parser.add_argument(
'--current_endpoint',
type=str,
default='127.0.0.1:6000',
help='The current pserver endpoint (default: 127.0.0.1:6000)')
parser.add_argument(
'--trainer_id',
type=int,
default=0,
help='The current trainer id (default: 0)')
parser.add_argument(
'--trainers',
type=int,
default=1,
help='The number of trainers (default: 1)')
return parser.parse_args()
@@ -216,27 +184,51 @@ def train():
optimizer = fluid.optimizer.Adam(learning_rate=1e-3)
optimizer.minimize(loss)
if args.is_local:
if os.environ["PADDLE_IS_LOCAL"] == "1":
logger.info("run local training")
main_program = fluid.default_main_program()
train_loop(args, main_program, word2vec_reader, py_reader, loss, 0)
else:
logger.info("run dist training")
t = fluid.DistributeTranspiler()
t.transpile(
args.trainer_id, pservers=args.endpoints, trainers=args.trainers)
if args.role == "pserver":
trainer_id = int(os.environ["PADDLE_TRAINER_ID"])
trainers = int(os.environ["PADDLE_TRAINERS"])
training_role = os.environ["PADDLE_TRAINING_ROLE"]
ports = os.getenv("PADDLE_PSERVER_PORTS", "6174")
pserver_ip = os.getenv("PADDLE_IP", "")
eplist = []
for port in ports.split(","):
eplist.append(':'.join([pserver_ip, port]))
pserver_endpoints = ",".join(eplist)
current_endpoint = pserver_ip + ":" + os.getenv("CUR_PORT", "2333")
config = fluid.DistributeTranspilerConfig()
config.slice_var_up = False
t = fluid.DistributeTranspiler(config=config)
t.transpile(trainer_id, pservers=pserver_endpoints, trainers=trainers, sync_mode=True)
if training_role == "PSERVER":
logger.info("run pserver")
prog = t.get_pserver_program(args.current_endpoint)
startup = t.get_startup_program(
args.current_endpoint, pserver_program=prog)
prog = t.get_pserver_program(current_endpoint)
startup = t.get_startup_program(current_endpoint, pserver_program=prog)
with open("pserver.main.proto.{}".format(os.getenv("CUR_PORT")), "w") as f:
f.write(str(prog))
exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup)
exe.run(prog)
elif args.role == "trainer":
elif training_role == "TRAINER":
logger.info("run trainer")
train_prog = t.get_trainer_program()
train_loop(args, train_prog, word2vec_reader, py_reader, loss, args.trainer_id)
with open("trainer.main.proto.{}".format(trainer_id), "w") as f:
f.write(str(train_prog))
train_loop(args, train_prog, word2vec_reader, py_reader, loss, trainer_id)
if __name__ == '__main__':
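To make the new control flow concrete, an illustrative manual launch of one pserver and one trainer with the same environment variables the patched train.py reads (values mirror cluster_train.sh; a sketch, not part of the commit):

export PADDLE_IS_LOCAL=0
export PADDLE_IP=127.0.0.1
export PADDLE_PSERVER_PORTS=36001   # train.py joins this with PADDLE_IP into 127.0.0.1:36001
export PADDLE_TRAINERS=1
export CPU_NUM=2
# parameter server process
PADDLE_TRAINING_ROLE=PSERVER CUR_PORT=36001 PADDLE_TRAINER_ID=0 \
    python -u train.py --is_sparse &> pserver.0.log &
# trainer process
PADDLE_TRAINING_ROLE=TRAINER PADDLE_TRAINER_ID=0 \
    python -u train.py --is_sparse --train_data_path data/enwik8 --dict_path data/enwik8_dict &> trainer.0.log &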