Unverified commit 969939e7, authored by lilong12 and committed by GitHub

Add fleet for transformer benchmark (#5164)

* add fleet, test=develop
Parent 85105600
@@ -96,4 +96,11 @@ dropout: 0.1
# Vocabularies in source and target should be same for weight sharing.
weight_sharing: True
# Use amp or not
use_amp: False
scale_loss: 1.0
# Whether to use multi-card/multi-node distributed training.
is_distributed: True
max_iter: None
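The three new keys are consumed by train.py below: is_distributed switches the script onto the Fleet collective code path, use_amp enables mixed-precision training through the distributed strategy, and scale_loss becomes the init_loss_scaling value in amp_configs. With is_distributed set to True, multi-GPU training is launched through paddle.distributed.launch, for example: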
python -m paddle.distributed.launch \
    --gpus="0,1" \
    train.py
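paddle.distributed.launch starts one training process per GPU listed in --gpus and, as far as I can tell, tells each process which device it owns through environment variables such as FLAGS_selected_gpus, which the modified train.py reads below. A minimal sketch of what each launched process sees (a hypothetical helper for illustration only, not part of this PR):

    # check_rank.py - hypothetical helper, launched the same way as train.py
    import os

    import paddle
    import paddle.distributed as dist
    import paddle.distributed.fleet as fleet

    paddle.enable_static()
    fleet.init(is_collective=True)

    # paddle.distributed.launch sets FLAGS_selected_gpus per process.
    gpu_id = int(os.getenv("FLAGS_selected_gpus", "0"))
    print("rank {} of {} bound to GPU {}".format(
        dist.get_rank(), dist.get_world_size(), gpu_id))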
@@ -10,6 +10,7 @@ from attrdict import AttrDict
from pprint import pprint
import paddle
import paddle.distributed.fleet as fleet
import paddle.distributed as dist
from paddlenlp.transformers import TransformerModel, CrossEntropyCriterion
@@ -36,6 +37,12 @@ def parse_args():
def do_train(args):
    paddle.enable_static()
    if args.is_distributed:
        fleet.init(is_collective=True)
        gpu_id = int(os.getenv("FLAGS_selected_gpus", "0"))
        places = paddle.CUDAPlace(gpu_id) if args.use_gpu else paddle.static.cpu_places()
        trainer_count = 1 if args.use_gpu else len(places)
    else:
        places = paddle.static.cuda_places() if args.use_gpu else paddle.static.cpu_places()
        trainer_count = len(places)
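Note the asymmetry between the two branches: on the Fleet collective path each launched process owns exactly one paddle.CUDAPlace, picked from FLAGS_selected_gpus, so trainer_count is 1 per process and data parallelism comes from running several processes; on the non-distributed path a single process drives all visible GPUs through paddle.static.cuda_places(), so trainer_count equals the device count and the training loop below feeds one dict per device.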
@@ -88,11 +95,28 @@ def do_train(args):
        epsilon=float(args.eps),
        parameters=transformer.parameters())

    if args.is_distributed:
        build_strategy = paddle.static.BuildStrategy()
        exec_strategy = paddle.static.ExecutionStrategy()
        dist_strategy = fleet.DistributedStrategy()
        dist_strategy.build_strategy = build_strategy
        dist_strategy.execution_strategy = exec_strategy
        dist_strategy.fuse_grad_size_in_MB = 16
        if args.use_amp:
            dist_strategy.amp = True
            dist_strategy.amp_configs = {
                'custom_white_list': ['softmax', 'layer_norm', 'gelu'],
                'init_loss_scaling': args.scale_loss,
            }
        optimizer = fleet.distributed_optimizer(optimizer, strategy=dist_strategy)
    optimizer.minimize(avg_cost)

    if args.is_distributed:
        exe = paddle.static.Executor(places)
    else:
        exe = paddle.static.Executor()
    exe.run(startup_program)

    build_strategy = paddle.static.BuildStrategy()
    exec_strategy = paddle.static.ExecutionStrategy()
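On the distributed path the optimizer is wrapped by fleet.distributed_optimizer before minimize, so that minimize also applies the distributed strategy (gradient all-reduce, AMP rewriting, and so on) to the program. fuse_grad_size_in_MB = 16 asks Fleet to fuse gradients into buffers of roughly 16 MB before all-reduce (my reading of that strategy option), and when use_amp is on, the listed ops (softmax, layer_norm, gelu) go onto the AMP custom white list with scale_loss as the initial loss scaling. Also note that the CompiledProgram built in the next hunk is only executed on the non-distributed path; in distributed mode the training loop runs train_program directly.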
@@ -101,6 +125,8 @@ def do_train(args):
            loss_name=avg_cost.name,
            build_strategy=build_strategy,
            exec_strategy=exec_strategy)

    exe.run(startup_program)
    # the best cross-entropy value with label smoothing
    loss_normalizer = -(
@@ -127,6 +153,15 @@ def do_train(args):
                data = [data]
            train_reader_cost = time.time() - batch_start

            if args.is_distributed:
                outs = exe.run(train_program,
                               feed=[{
                                   'src_word': data[i][0],
                                   'trg_word': data[i][1],
                                   'lbl_word': data[i][2],
                               } for i in range(trainer_count)],
                               fetch_list=[sum_cost.name, token_num.name])
            else:
                outs = exe.run(compiled_train_program,
                               feed=[{
                                   'src_word': data[i][0],
@@ -176,7 +211,7 @@ def do_train(args):
                batch_ips_avg.reset()

            if step_idx % args.save_step == 0 and step_idx != 0:
                if args.save_model and dist.get_rank() == 0:
                    model_path = os.path.join(
                        args.save_model, "step_" + str(step_idx), "transformer")
                    paddle.static.save(train_program, model_path)
......
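Only rank 0 writes checkpoints now (the guard previously checked args.save_model alone), which keeps several processes from racing on the same step_*/transformer files. A checkpoint written by paddle.static.save can be restored with paddle.static.load; the snippet below is a minimal, self-contained illustration of that save/load pairing on a toy program standing in for train_program, not code from this PR:

    # Toy stand-in for train_program, only to show the save/load calls used above.
    import os

    import paddle

    paddle.enable_static()
    main_prog = paddle.static.Program()
    startup_prog = paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        x = paddle.static.data(name="x", shape=[None, 8], dtype="float32")
        y = paddle.static.nn.fc(x, size=4)

    exe = paddle.static.Executor(paddle.CPUPlace())
    exe.run(startup_prog)

    os.makedirs("checkpoints/step_0", exist_ok=True)
    # Same call pattern as the rank-0 save in the diff ...
    paddle.static.save(main_prog, "checkpoints/step_0/transformer")
    # ... and the matching restore into the same program.
    paddle.static.load(main_prog, "checkpoints/step_0/transformer", exe)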