提交 db05813b 编写于 作者: L LielinJiang 提交者: Zeyu Chen

Add multiple process train and mix precision train (#60)

* fp16_err

* add fp16 and multiple gpus train

* add fp16 and multiple gpus train

* add fp16 and multiple gpus train

* add fp16 and multiple gpus train

* add fp16 and multiple gpus train

* add fp16 and multiple gpus train

* add fp16 and multiple gpus train
上级 fca76c3a
......@@ -61,6 +61,7 @@ PaddleSeg支持多进程IO、多卡并行、跨卡Batch Norm同步等训练加
* [PaddleSeg的数据增强](./docs/data_aug.md)
* [特色垂类模型使用](./contrib)
* [多进程训练和混合精度训练](./docs/multiple_gpus_train_and_mixed_precision_train.md)
</br>
......
# PaddleSeg 多进程训练和混合精度训练
### 环境要求
* PaddlePaddle >= 1.6.0
* NVIDIA NCCL >= 2.4.7,并在Linux环境下运行
环境配置,数据,预训练模型准备等工作请参考[安装说明](./installation.md)[PaddleSeg使用说明](./usage.md)
### 多进程训练示例
多进程训练,可以按如下方式启动
```
export CUDA_VISIBLE_DEVICES=0,1
python -m paddle.distributed.launch pdseg/train.py --use_gpu \
--do_eval \
--cfg configs/unet_pet.yaml \
BATCH_SIZE 4 \
TRAIN.PRETRAINED_MODEL_DIR pretrained_model/unet_bn_coco \
SOLVER.LR 5e-5
```
### 混合精度训练示例
启动混合精度训练,只需将```MODEL.FP16```设置为```True```,具体命令如下
```
export CUDA_VISIBLE_DEVICES=0,1
python -m paddle.distributed.launch pdseg/train.py --use_gpu \
--do_eval \
--cfg configs/unet_pet.yaml \
BATCH_SIZE 4 \
TRAIN.PRETRAINED_MODEL_DIR pretrained_model/unet_bn_coco \
SOLVER.LR 5e-5 \
MODEL.FP16 True
```
这时候会采用动态scale的方式,若想使用静态scale的方式,可通过```MODEL.SCALE_LOSS```设置,具体命令如下
```
export CUDA_VISIBLE_DEVICES=0,1
python -m paddle.distributed.launch pdseg/train.py --use_gpu \
--do_eval \
--cfg configs/unet_pet.yaml \
BATCH_SIZE 8 \
TRAIN.PRETRAINED_MODEL_DIR pretrained_model/unet_bn_coco \
SOLVER.LR 5e-5 \
MODEL.FP16 True \
MODEL.SCALE_LOSS 512.0
```
### benchmark
| 模型 | 数据集合 | batch size | number gpu cards | 多进程训练 | 混合精度训练 | 显存占用 | 速度(image/s) | mIoU on val |
|---|---|---|---|---|---|---|---|---|
| DeepLabv3+/Xception65/bn | Cityscapes | 16 | 4 | False | False | 15988 MiB | 17.27 | 79.20 |
| DeepLabv3+/Xception65/bn | Cityscapes | 16 | 4 | True | False | 15814 MiB | 19.80 | 78.90 |
| DeepLabv3+/Xception65/bn | Cityscapes | 16 | 4 | True | True | 14922 MiB | 25.84 |79.06|
### 参考
- [Mixed Precision Training](https://arxiv.org/abs/1710.03740)
......@@ -29,6 +29,7 @@ def softmax_with_loss(logit, label, ignore_mask=None, num_classes=2):
label = fluid.layers.reshape(label, [-1, 1])
label = fluid.layers.cast(label, 'int64')
ignore_mask = fluid.layers.reshape(ignore_mask, [-1, 1])
loss, probs = fluid.layers.softmax_with_cross_entropy(
logit,
label,
......@@ -36,14 +37,8 @@ def softmax_with_loss(logit, label, ignore_mask=None, num_classes=2):
return_softmax=True)
loss = loss * ignore_mask
if cfg.MODEL.FP16:
loss = fluid.layers.cast(loss, 'float32')
avg_loss = fluid.layers.mean(loss) / fluid.layers.mean(ignore_mask)
avg_loss = fluid.layers.cast(avg_loss, 'float16')
else:
avg_loss = fluid.layers.mean(loss) / fluid.layers.mean(ignore_mask)
if cfg.MODEL.SCALE_LOSS > 1.0:
avg_loss = avg_loss * cfg.MODEL.SCALE_LOSS
avg_loss = fluid.layers.mean(loss) / fluid.layers.mean(ignore_mask)
label.stop_gradient = True
ignore_mask.stop_gradient = True
return avg_loss
......
......@@ -228,7 +228,7 @@ class MobileNetV2():
num_groups=num_expfilter,
if_act=True,
name=name + '_dwise',
use_cudnn=True if cfg.MODEL.FP16 else False)
use_cudnn=False)
depthwise_output = bottleneck_conv
......
......@@ -149,7 +149,7 @@ def separate_conv(input, channel, stride, filter, dilation=1, act=None):
groups=input.shape[1],
padding=(filter // 2) * dilation,
dilation=dilation,
use_cudnn=True if cfg.MODEL.FP16 else False,
use_cudnn=False,
param_attr=param_attr)
input = bn(input)
if act: input = act(input)
......
......@@ -153,8 +153,7 @@ def build_model(main_prog, start_prog, phase=ModelPhase.TRAIN):
capacity=cfg.DATALOADER.BUF_SIZE,
iterable=False,
use_double_buffer=True)
if cfg.MODEL.FP16:
image = fluid.layers.cast(image, "float16")
model_name = map_model_name(cfg.MODEL.MODEL_NAME)
model_func = get_func("modeling." + model_name)
......@@ -203,13 +202,13 @@ def build_model(main_prog, start_prog, phase=ModelPhase.TRAIN):
else:
logit = softmax(logit)
return image, logit
if class_num == 1:
out = sigmoid_to_softmax(logit)
out = fluid.layers.transpose(out, [0, 2, 3, 1])
else:
out = fluid.layers.transpose(logit, [0, 2, 3, 1])
if cfg.MODEL.FP16:
out = fluid.layers.cast(out, 'float32')
pred = fluid.layers.argmax(out, axis=3)
pred = fluid.layers.unsqueeze(pred, axes=[3])
if ModelPhase.is_visual(phase):
......
......@@ -27,7 +27,6 @@ from models.libs.model_libs import separate_conv
from models.backbone.mobilenet_v2 import MobileNetV2 as mobilenet_backbone
from models.backbone.xception import Xception as xception_backbone
def encoder(input):
# 编码器配置,采用ASPP架构,pooling + 1x1_conv + 三个不同尺度的空洞卷积并行, concat后1x1conv
# ASPP_WITH_SEP_CONV:默认为真,使用depthwise可分离卷积,否则使用普通卷积
......@@ -48,13 +47,8 @@ def encoder(input):
with scope('encoder'):
channel = 256
with scope("image_pool"):
if cfg.MODEL.FP16:
image_avg = fluid.layers.reduce_mean(
fluid.layers.cast(input, 'float32'), [2, 3], keep_dim=True)
image_avg = fluid.layers.cast(image_avg, 'float16')
else:
image_avg = fluid.layers.reduce_mean(
input, [2, 3], keep_dim=True)
image_avg = fluid.layers.reduce_mean(
input, [2, 3], keep_dim=True)
image_avg = bn_relu(
conv(
image_avg,
......@@ -64,11 +58,8 @@ def encoder(input):
groups=1,
padding=0,
param_attr=param_attr))
if cfg.MODEL.FP16:
image_avg = fluid.layers.cast(image_avg, 'float32')
image_avg = fluid.layers.resize_bilinear(image_avg, input.shape[2:])
if cfg.MODEL.FP16:
image_avg = fluid.layers.cast(image_avg, 'float16')
with scope("aspp0"):
aspp0 = bn_relu(
conv(
......@@ -157,12 +148,9 @@ def decoder(encode_data, decode_shortcut):
groups=1,
padding=0,
param_attr=param_attr))
if cfg.MODEL.FP16:
encode_data = fluid.layers.cast(encode_data, 'float32')
encode_data = fluid.layers.resize_bilinear(
encode_data, decode_shortcut.shape[2:])
if cfg.MODEL.FP16:
encode_data = fluid.layers.cast(encode_data, 'float16')
encode_data = fluid.layers.concat([encode_data, decode_shortcut],
axis=1)
if cfg.MODEL.DEEPLAB.DECODER_USE_SEP_CONV:
......@@ -270,9 +258,6 @@ def deeplabv3p(img, num_classes):
padding=0,
bias_attr=True,
param_attr=param_attr)
if cfg.MODEL.FP16:
logit = fluid.layers.cast(logit, 'float32')
logit = fluid.layers.resize_bilinear(logit, img.shape[2:])
if cfg.MODEL.FP16:
logit = fluid.layers.cast(logit, 'float16')
return logit
......@@ -32,7 +32,7 @@ import data_aug as aug
from utils.config import cfg
from data_utils import GeneratorEnqueuer
from models.model_builder import ModelPhase
import copy
def cv2_imread(file_path, flag=cv2.IMREAD_COLOR):
# resolve cv2.imread open Chinese file path issues on Windows Platform.
......@@ -49,15 +49,25 @@ class SegDataset(object):
self.shuffle = shuffle
self.data_dir = data_dir
self.shuffle_seed = 0
# NOTE: Please ensure file list was save in UTF-8 coding format
with codecs.open(file_list, 'r', 'utf-8') as flist:
self.lines = [line.strip() for line in flist]
if shuffle:
self.all_lines = copy.deepcopy(self.lines)
if shuffle and cfg.NUM_TRAINERS > 1:
np.random.RandomState(self.shuffle_seed).shuffle(self.all_lines)
elif shuffle:
np.random.shuffle(self.lines)
def generator(self):
if self.shuffle:
if self.shuffle and cfg.NUM_TRAINERS > 1:
np.random.RandomState(self.shuffle_seed).shuffle(self.all_lines)
num_lines = len(self.all_lines) // cfg.NUM_TRAINERS
self.lines = self.all_lines[num_lines * cfg.TRAINER_ID: num_lines * (cfg.TRAINER_ID + 1)]
self.shuffle_seed += 1
elif self.shuffle:
np.random.shuffle(self.lines)
for line in self.lines:
yield self.process_image(line, self.data_dir, self.mode)
......@@ -78,8 +88,14 @@ class SegDataset(object):
def multiprocess_generator(self, max_queue_size=32, num_processes=8):
# Re-shuffle file list
if self.shuffle:
if self.shuffle and cfg.NUM_TRAINERS > 1:
np.random.RandomState(self.shuffle_seed).shuffle(self.all_lines)
num_lines = len(self.all_lines) // self.num_trainers
self.lines = self.all_lines[num_lines * self.trainer_id: num_lines * (self.trainer_id + 1)]
self.shuffle_seed += 1
elif self.shuffle:
np.random.shuffle(self.lines)
# Create multiple sharding generators according to num_processes for multiple processes
generators = []
for pid in range(num_processes):
......
......@@ -18,7 +18,7 @@ import paddle.fluid as fluid
import numpy as np
import importlib
from utils.config import cfg
from paddle.fluid.contrib.mixed_precision.fp16_utils import create_master_params_grads, master_param_to_train_param
from paddle.fluid.contrib.mixed_precision.decorator import OptimizerWithMixedPrecison, decorate, AutoMixedPrecisionLists
class Solver(object):
......@@ -74,15 +74,22 @@ class Solver(object):
regularization_coeff=self.weight_decay),
)
if cfg.MODEL.FP16:
params_grads = optimizer.backward(loss, self.start_prog)
master_params_grads = create_master_params_grads(
params_grads, self.main_prog, self.start_prog,
cfg.MODEL.SCALE_LOSS)
optimizer.apply_gradients(master_params_grads)
master_param_to_train_param(master_params_grads, params_grads,
self.main_prog)
else:
optimizer.minimize(loss)
if cfg.MODEL.MODEL_NAME in ["pspnet"]:
custom_black_list = {"pool2d"}
else:
custom_black_list = {}
amp_lists = AutoMixedPrecisionLists(custom_black_list=custom_black_list)
assert isinstance(cfg.MODEL.SCALE_LOSS, float) or isinstance(cfg.MODEL.SCALE_LOSS, str), \
"data type of MODEL.SCALE_LOSS must be float or str"
if isinstance(cfg.MODEL.SCALE_LOSS, float):
optimizer = decorate(optimizer, amp_lists=amp_lists, init_loss_scaling=cfg.MODEL.SCALE_LOSS,
use_dynamic_loss_scaling=False)
else:
assert cfg.MODEL.SCALE_LOSS.lower() in ['dynamic'], "if MODEL.SCALE_LOSS is a string,\
must be set as 'DYNAMIC'!"
optimizer = decorate(optimizer, amp_lists=amp_lists, use_dynamic_loss_scaling=True)
optimizer.minimize(loss)
return decayed_lr
def adam_optimizer(self, lr_policy, loss):
......
......@@ -40,8 +40,7 @@ from models.model_builder import ModelPhase
from models.model_builder import parse_shape_from_file
from eval import evaluate
from vis import visualize
from utils.fp16_utils import load_fp16_vars
from utils import dist_utils
def parse_args():
parser = argparse.ArgumentParser(description='PaddleSeg training')
......@@ -178,6 +177,9 @@ def load_checkpoint(exe, program):
return begin_epoch
def print_info(*msg):
if cfg.TRAINER_ID == 0:
print(*msg)
def train(cfg):
startup_prog = fluid.Program()
......@@ -201,7 +203,7 @@ def train(cfg):
batch_data = []
for b in data_gen:
batch_data.append(b)
if len(batch_data) == cfg.BATCH_SIZE:
if len(batch_data) == (cfg.BATCH_SIZE // cfg.NUM_TRAINERS):
for item in batch_data:
yield item[0], item[1], item[2]
batch_data = []
......@@ -212,11 +214,15 @@ def train(cfg):
yield item[0], item[1], item[2]
# Get device environment
# places = fluid.cuda_places() if args.use_gpu else fluid.cpu_places()
# place = places[0]
gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
place = fluid.CUDAPlace(gpu_id) if args.use_gpu else fluid.CPUPlace()
places = fluid.cuda_places() if args.use_gpu else fluid.cpu_places()
place = places[0]
# Get number of GPU
dev_count = len(places)
print("#Device count: {}".format(dev_count))
dev_count = cfg.NUM_TRAINERS if cfg.NUM_TRAINERS > 1 else len(places)
print_info("#Device count: {}".format(dev_count))
# Make sure BATCH_SIZE can divided by GPU cards
assert cfg.BATCH_SIZE % dev_count == 0, (
......@@ -224,7 +230,7 @@ def train(cfg):
cfg.BATCH_SIZE, dev_count))
# If use multi-gpu training mode, batch data will allocated to each GPU evenly
batch_size_per_dev = cfg.BATCH_SIZE // dev_count
print("batch_size_per_dev: {}".format(batch_size_per_dev))
print_info("batch_size_per_dev: {}".format(batch_size_per_dev))
py_reader, avg_loss, lr, pred, grts, masks = build_model(
train_prog, startup_prog, phase=ModelPhase.TRAIN)
......@@ -240,13 +246,18 @@ def train(cfg):
exec_strategy.num_threads = fluid.core.get_cuda_device_count()
exec_strategy.num_iteration_per_drop_scope = 100
build_strategy = fluid.BuildStrategy()
if cfg.NUM_TRAINERS > 1 and args.use_gpu:
dist_utils.prepare_for_multi_process(exe, build_strategy, train_prog)
exec_strategy.num_threads = 1
if cfg.TRAIN.SYNC_BATCH_NORM and args.use_gpu:
if dev_count > 1:
# Apply sync batch norm strategy
print("Sync BatchNorm strategy is effective.")
print_info("Sync BatchNorm strategy is effective.")
build_strategy.sync_batch_norm = True
else:
print("Sync BatchNorm strategy will not be effective if GPU device"
print_info("Sync BatchNorm strategy will not be effective if GPU device"
" count <= 1")
compiled_train_prog = fluid.CompiledProgram(train_prog).with_data_parallel(
loss_name=avg_loss.name,
......@@ -259,7 +270,7 @@ def train(cfg):
begin_epoch = load_checkpoint(exe, train_prog)
# Load pretrained model
elif os.path.exists(cfg.TRAIN.PRETRAINED_MODEL_DIR):
print('Pretrained model dir:', cfg.TRAIN.PRETRAINED_MODEL_DIR)
print_info('Pretrained model dir: ', cfg.TRAIN.PRETRAINED_MODEL_DIR)
load_vars = []
load_fail_vars = []
......@@ -283,22 +294,19 @@ def train(cfg):
load_vars.append(x)
else:
load_fail_vars.append(x)
if cfg.MODEL.FP16:
# If open FP16 training mode, load FP16 var separate
load_fp16_vars(exe, cfg.TRAIN.PRETRAINED_MODEL_DIR, train_prog)
else:
fluid.io.load_vars(
exe, dirname=cfg.TRAIN.PRETRAINED_MODEL_DIR, vars=load_vars)
fluid.io.load_vars(
exe, dirname=cfg.TRAIN.PRETRAINED_MODEL_DIR, vars=load_vars)
for var in load_vars:
print("Parameter[{}] loaded sucessfully!".format(var.name))
print_info("Parameter[{}] loaded sucessfully!".format(var.name))
for var in load_fail_vars:
print("Parameter[{}] don't exist or shape does not match current network, skip"
print_info("Parameter[{}] don't exist or shape does not match current network, skip"
" to load it.".format(var.name))
print("{}/{} pretrained parameters loaded successfully!".format(
print_info("{}/{} pretrained parameters loaded successfully!".format(
len(load_vars),
len(load_vars) + len(load_fail_vars)))
else:
print('Pretrained model dir {} not exists, training from scratch...'.
print_info('Pretrained model dir {} not exists, training from scratch...'.
format(cfg.TRAIN.PRETRAINED_MODEL_DIR))
fetch_list = [avg_loss.name, lr.name]
......@@ -312,12 +320,14 @@ def train(cfg):
if args.use_tb:
if not args.tb_log_dir:
print("Please specify the log directory by --tb_log_dir.")
print_info("Please specify the log directory by --tb_log_dir.")
exit(1)
from tb_paddle import SummaryWriter
log_writer = SummaryWriter(args.tb_log_dir)
# trainer_id = int(os.getenv("PADDLE_TRAINER_ID", 0))
# num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
global_step = 0
all_step = cfg.DATASET.TRAIN_TOTAL_IMAGES // cfg.BATCH_SIZE
if cfg.DATASET.TRAIN_TOTAL_IMAGES % cfg.BATCH_SIZE and drop_last != True:
......@@ -333,9 +343,9 @@ def train(cfg):
begin_epoch, cfg.SOLVER.NUM_EPOCHS))
if args.use_mpio:
print("Use multiprocess reader")
print_info("Use multiprocess reader")
else:
print("Use multi-thread reader")
print_info("Use multi-thread reader")
for epoch in range(begin_epoch, cfg.SOLVER.NUM_EPOCHS + 1):
py_reader.start()
......@@ -348,7 +358,6 @@ def train(cfg):
program=compiled_train_prog,
fetch_list=fetch_list,
return_numpy=True)
cm.calculate(pred, grts, masks)
avg_loss += np.mean(np.array(loss))
global_step += 1
......@@ -359,13 +368,13 @@ def train(cfg):
category_acc, mean_acc = cm.accuracy()
category_iou, mean_iou = cm.mean_iou()
print((
print_info((
"epoch={} step={} lr={:.5f} loss={:.4f} acc={:.5f} mIoU={:.5f} step/sec={:.3f} | ETA {}"
).format(epoch, global_step, lr[0], avg_loss, mean_acc,
mean_iou, speed,
calculate_eta(all_step - global_step, speed)))
print("Category IoU:", category_iou)
print("Category Acc:", category_acc)
print_info("Category IoU: ", category_iou)
print_info("Category Acc: ", category_acc)
if args.use_tb:
log_writer.add_scalar('Train/mean_iou', mean_iou,
global_step)
......@@ -390,7 +399,7 @@ def train(cfg):
avg_loss += np.mean(np.array(loss))
global_step += 1
if global_step % args.log_steps == 0:
if global_step % args.log_steps == 0 and cfg.TRAINER_ID == 0:
avg_loss /= args.log_steps
speed = args.log_steps / timer.elapsed_time()
print((
......@@ -414,7 +423,7 @@ def train(cfg):
except Exception as e:
print(e)
if epoch % cfg.TRAIN.SNAPSHOT_EPOCH == 0:
if epoch % cfg.TRAIN.SNAPSHOT_EPOCH == 0 and cfg.TRAINER_ID == 0:
ckpt_dir = save_checkpoint(exe, train_prog, epoch)
if args.do_eval:
......@@ -441,16 +450,20 @@ def train(cfg):
log_writer=log_writer)
# save final model
save_checkpoint(exe, train_prog, 'final')
if cfg.TRAINER_ID == 0:
save_checkpoint(exe, train_prog, 'final')
def main(args):
if args.cfg_file is not None:
cfg.update_from_file(args.cfg_file)
if args.opts is not None:
cfg.update_from_list(args.opts)
cfg.TRAINER_ID = int(os.getenv("PADDLE_TRAINER_ID", 0))
cfg.NUM_TRAINERS = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
cfg.check_and_infer(reset_dataset=True)
print(pprint.pformat(cfg))
print_info(pprint.pformat(cfg))
train(cfg)
......
......@@ -139,7 +139,7 @@ class SegConfig(dict):
def update_from_file(self, config_file):
with codecs.open(config_file, 'r', 'utf-8') as file:
dic = yaml.load(file)
dic = yaml.load(file, Loader=yaml.FullLoader)
self.update_from_segconfig(dic)
def set_immutable(self, immutable):
......
......@@ -31,7 +31,10 @@ cfg.BATCH_SIZE = 1
cfg.EVAL_CROP_SIZE = tuple()
# 训练时图像裁剪尺寸(宽,高)
cfg.TRAIN_CROP_SIZE = tuple()
# 多进程训练总进程数
cfg.NUM_TRAINERS = 1
# 多进程训练进程ID
cfg.TRAINER_ID = 0
########################## 数据载入配置 #######################################
# 数据载入时的并发数, 建议值8
cfg.DATALOADER.NUM_WORKERS = 8
......@@ -171,8 +174,8 @@ cfg.MODEL.DEFAULT_EPSILON = 1e-5
cfg.MODEL.BN_MOMENTUM = 0.99
# 是否使用FP16训练
cfg.MODEL.FP16 = False
# FP16需对LOSS进行scale, 一般训练FP16设置为8.0
cfg.MODEL.SCALE_LOSS = 1.0
# 混合精度训练需对LOSS进行scale, 默认为动态scale,静态scale可以设置为512.0
cfg.MODEL.SCALE_LOSS = "DYNAMIC"
########################## DeepLab模型配置 ####################################
# DeepLab backbone 配置, 可选项xception_65, mobilenetv2
......
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import paddle.fluid as fluid
def nccl2_prepare(args, startup_prog, main_prog):
config = fluid.DistributeTranspilerConfig()
config.mode = "nccl2"
t = fluid.DistributeTranspiler(config=config)
envs = args.dist_env
t.transpile(
envs["trainer_id"],
trainers=','.join(envs["trainer_endpoints"]),
current_endpoint=envs["current_endpoint"],
startup_program=startup_prog,
program=main_prog)
def pserver_prepare(args, train_prog, startup_prog):
config = fluid.DistributeTranspilerConfig()
config.slice_var_up = args.split_var
t = fluid.DistributeTranspiler(config=config)
envs = args.dist_env
training_role = envs["training_role"]
t.transpile(
envs["trainer_id"],
program=train_prog,
pservers=envs["pserver_endpoints"],
trainers=envs["num_trainers"],
sync_mode=not args.async_mode,
startup_program=startup_prog)
if training_role == "PSERVER":
pserver_program = t.get_pserver_program(envs["current_endpoint"])
pserver_startup_program = t.get_startup_program(
envs["current_endpoint"],
pserver_program,
startup_program=startup_prog)
return pserver_program, pserver_startup_program
elif training_role == "TRAINER":
train_program = t.get_trainer_program()
return train_program, startup_prog
else:
raise ValueError(
'PADDLE_TRAINING_ROLE environment variable must be either TRAINER or PSERVER'
)
def nccl2_prepare_paddle(trainer_id, startup_prog, main_prog):
config = fluid.DistributeTranspilerConfig()
config.mode = "nccl2"
t = fluid.DistributeTranspiler(config=config)
t.transpile(
trainer_id,
trainers=os.environ.get('PADDLE_TRAINER_ENDPOINTS'),
current_endpoint=os.environ.get('PADDLE_CURRENT_ENDPOINT'),
startup_program=startup_prog,
program=main_prog)
def prepare_for_multi_process(exe, build_strategy, train_prog):
# prepare for multi-process
trainer_id = int(os.environ.get('PADDLE_TRAINER_ID', 0))
num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
if num_trainers < 2: return
build_strategy.num_trainers = num_trainers
build_strategy.trainer_id = trainer_id
# NOTE(zcd): use multi processes to train the model,
# and each process use one GPU card.
startup_prog = fluid.Program()
nccl2_prepare_paddle(trainer_id, startup_prog, train_prog)
# the startup_prog are run two times, but it doesn't matter.
exe.run(startup_prog)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册