Unverified commit c50d1f18, authored by gongweibao, committed by GitHub

Add DGC support (#2031)

Parent cb764459
......@@ -108,3 +108,24 @@ The second figure shows speed-ups when using multiple GPUs according to the abov
Speed-ups of Multi-GPU Training of ResNet50 on ImageNet
</p>
## Deep Gradient Compression ([arXiv:1712.01887](https://arxiv.org/abs/1712.01887)) for ResNet
#### Environment
- GPU: NVIDIA® Tesla® V100
- Machines * GPUs per machine: 4 * 4
- System: CentOS 6u3
- CUDA/cuDNN: 9.0/7.1
- Dataset: ImageNet
- Date: 2019.04
- Paddle version: 1.4
- Batch size: 32
#### Performance
<p align="center">
<img src="../images/resnet_dgc.png" width=528> <br />
Performance of DGC for ResNet (FP32) under different bandwidths
</p>
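DGC cuts communication cost by exchanging only the largest-magnitude gradient values each step and accumulating the rest locally until they grow large enough to send. A minimal NumPy sketch of that top-k sparsification idea (illustrative only: `dgc_sparsify` is a made-up helper here, and the real logic lives inside PaddlePaddle's `DGCMomentumOptimizer`):

```python
import numpy as np

def dgc_sparsify(grad, residual, sparsity=0.999):
    # Accumulate this step's gradient into the local residual.
    residual += grad
    # Keep only the top (1 - sparsity) fraction of values by magnitude.
    k = max(1, int(round(residual.size * (1.0 - sparsity))))
    threshold = np.partition(np.abs(residual).ravel(), -k)[-k]
    mask = np.abs(residual) >= threshold
    sparse_grad = np.where(mask, residual, 0.0)  # what gets all-reduced
    residual[mask] = 0.0  # values that were sent leave the residual
    return sparse_grad, residual
```

With `sparsity=0.999`, only about 0.1% of the gradient values travel over the network each step, which is why the curves above degrade far more gracefully as bandwidth shrinks.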
......@@ -68,6 +68,9 @@ def parse_args():
add_arg('reduce_strategy', str, "allreduce", "Choose from reduce or allreduce.")
add_arg('skip_unbalanced_data', bool, False, "Skip batches if data is not balanced across nodes.")
add_arg('enable_sequential_execution', bool, False, "Enable sequential execution of operators.")
# For DGC.
add_arg('enable_dgc', bool, False, "Enable Deep Gradient Compression (DGC).")
add_arg('rampup_begin_step', int, 5008, "The step at which DGC sparsity ramp-up begins.")
# yapf: enable
args = parser.parse_args()
return args
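The launch script below passes literal strings such as `--enable_dgc True`, so boolean flags cannot rely on argparse's bare `bool` (any non-empty string is truthy, `"False"` included). A hedged sketch of how an `add_arg` helper typically routes booleans through `strtobool` (the repo's actual helper may differ):

```python
import argparse
from distutils.util import strtobool

parser = argparse.ArgumentParser()

def add_arg(name, dtype, default, help_text):
    # bool("False") is True, so boolean flags are parsed with strtobool,
    # which understands "True"/"False"/"1"/"0" and friends.
    arg_type = (lambda s: bool(strtobool(s))) if dtype is bool else dtype
    parser.add_argument('--' + name, type=arg_type, default=default, help=help_text)

add_arg('enable_dgc', bool, False, "Enable Deep Gradient Compression (DGC).")
add_arg('rampup_begin_step', int, 5008, "The step at which DGC sparsity ramp-up begins.")

args = parser.parse_args(['--enable_dgc', 'True'])
assert args.enable_dgc is True and args.rampup_begin_step == 5008
```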
......@@ -157,6 +160,17 @@ def build_program(is_train, main_prog, startup_prog, args):
            boundaries=bd, values=lr),
        warmup_steps, start_lr, end_lr),
    momentum=0.9)
if args.enable_dgc:
    optimizer = fluid.optimizer.DGCMomentumOptimizer(
        learning_rate=utils.learning_rate.lr_warmup(
            fluid.layers.piecewise_decay(
                boundaries=bd, values=lr),
            warmup_steps, start_lr, end_lr),
        momentum=0.9,
        sparsity=[0.999, 0.999],
        rampup_begin_step=args.rampup_begin_step)
if args.fp16:
    params_grads = optimizer.backward(avg_cost)
    master_params_grads = utils.create_master_params_grads(
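Stripped of the learning-rate schedule, the change above reduces to a single optimizer swap. A sketch under the assumption that a constant learning rate stands in for the warmup-plus-piecewise-decay schedule (parameter values are the ones this commit uses):

```python
import paddle.fluid as fluid

def build_optimizer(args, learning_rate=0.1):
    if args.enable_dgc:
        # Plain dense momentum updates run until rampup_begin_step;
        # after that, DGC sparsifies gradients to 99.9% before all-reduce.
        return fluid.optimizer.DGCMomentumOptimizer(
            learning_rate=learning_rate,
            momentum=0.9,
            sparsity=[0.999, 0.999],
            rampup_begin_step=args.rampup_begin_step)
    return fluid.optimizer.Momentum(learning_rate=learning_rate, momentum=0.9)
```

Deferring compression via `rampup_begin_step` mirrors the warm-up training trick from the DGC paper: the early, rapidly changing gradients are exchanged densely before sparsification kicks in.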
......@@ -224,7 +238,7 @@ def train_parallel(args):
if args.update_method == "pserver":
    train_prog, startup_prog = pserver_prepare(args, train_prog, startup_prog)
elif args.update_method == "nccl2":
    nccl2_prepare(args, startup_prog)
    nccl2_prepare(args, startup_prog, main_prog=train_prog)
if args.dist_env["training_role"] == "PSERVER":
    run_pserver(train_prog, startup_prog)
......@@ -247,11 +261,16 @@ def train_parallel(args):
strategy = fluid.ExecutionStrategy()
strategy.num_threads = args.num_threads
# num_iteration_per_drop_scope sets how many iterations to run before
# cleaning up the temporary variables generated during execution. This can
# speed things up, because a temp variable's shape is often the same
# across iterations, so its memory can be reused.
strategy.num_iteration_per_drop_scope = 30
build_strategy = fluid.BuildStrategy()
build_strategy.enable_inplace = False
build_strategy.memory_optimize = False
build_strategy.enable_sequential_execution = bool(args.enable_sequential_execution)
if args.reduce_strategy == "reduce":
    build_strategy.reduce_strategy = fluid.BuildStrategy(
......
......@@ -2,7 +2,7 @@ import os
import paddle.fluid as fluid
def nccl2_prepare(args, startup_prog):
def nccl2_prepare(args, startup_prog, main_prog):
    config = fluid.DistributeTranspilerConfig()
    config.mode = "nccl2"
    t = fluid.DistributeTranspiler(config=config)
......@@ -12,7 +12,8 @@ def nccl2_prepare(args, startup_prog):
    t.transpile(envs["trainer_id"],
                trainers=','.join(envs["trainer_endpoints"]),
                current_endpoint=envs["current_endpoint"],
                startup_program=startup_prog)
                startup_program=startup_prog,
                program=main_prog)
def pserver_prepare(args, train_prog, startup_prog):
......
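The functional change in this file is that `transpile()` now also receives the training program, so the transpiler rewrites it in place; the DGC path needs its ops inserted there. A standalone sketch of that call with the distributed-env values passed in directly (the function name and parameters here are illustrative):

```python
import paddle.fluid as fluid

def nccl2_transpile(trainer_id, trainer_endpoints, current_endpoint,
                    startup_prog, main_prog):
    config = fluid.DistributeTranspilerConfig()
    config.mode = "nccl2"
    t = fluid.DistributeTranspiler(config=config)
    # program= targets the explicit training program; if omitted, the
    # transpiler rewrites fluid.default_main_program() instead.
    t.transpile(trainer_id,
                trainers=','.join(trainer_endpoints),
                current_endpoint=current_endpoint,
                startup_program=startup_prog,
                program=main_prog)
```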
#!/bin/bash
set -e
enable_dgc=False
while true ; do
  case "$1" in
    -enable_dgc) enable_dgc="$2" ; shift 2 ;;
    *)
      if [[ ${#1} -gt 0 ]]; then
        echo "unsupported argument: ${1}" ; exit 1 ;
      else
        break
      fi
      ;;
  esac
done
case "${enable_dgc}" in
True) ;;
False) ;;
*) echo "not support argument -enable_dgc: ${dgc}" ; exit 1 ;;
esac
export MODEL="DistResNet"
export PADDLE_TRAINER_ENDPOINTS="127.0.0.1:7160,127.0.0.1:7161"
......@@ -9,16 +31,20 @@ mkdir -p logs
# NOTE: set NCCL_P2P_DISABLE so that nccl2 distributed training can run on a single node.
# You can raise GLOG_v to see more detailed logs.
# export GLOG_v=1
# export GLOG_logtostderr=1
PADDLE_TRAINING_ROLE="TRAINER" \
PADDLE_CURRENT_ENDPOINT="127.0.0.1:7160" \
PADDLE_TRAINER_ID="0" \
CUDA_VISIBLE_DEVICES="0" \
NCCL_P2P_DISABLE="1" \
python dist_train.py --model $MODEL --update_method nccl2 --batch_size 32 &> logs/tr0.log &
python -u dist_train.py --enable_dgc ${enable_dgc} --model $MODEL --update_method nccl2 --batch_size 32 &> logs/tr0.log &
PADDLE_TRAINING_ROLE="TRAINER" \
PADDLE_CURRENT_ENDPOINT="127.0.0.1:7161" \
PADDLE_TRAINER_ID="1" \
CUDA_VISIBLE_DEVICES="1" \
NCCL_P2P_DISABLE="1" \
python dist_train.py --model $MODEL --update_method nccl2 --batch_size 32 &> logs/tr1.log &
python -u dist_train.py --enable_dgc ${enable_dgc} --model $MODEL --update_method nccl2 --batch_size 32 &> logs/tr1.log &
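With the changes above, launching the two local trainers with DGC enabled is a matter of running this script (under whatever name it is saved) as `bash <script>.sh -enable_dgc True`, then tailing `logs/tr0.log` and `logs/tr1.log`; passing `-enable_dgc False`, or nothing at all, falls back to plain momentum SGD.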