Unverified commit 711aa19b, authored by WangXi, committed by GitHub

fix bert benchmark fleet compiled (#5117)

Parent d7009805
@@ -172,13 +172,27 @@ def reset_program_state_dict(model, state_dict):
        loc=0.0, scale=scale, size=p.shape).astype(dtype_str)
    return new_state_dict

def build_compiled_program(main_program, loss):
def create_strategy():
    """
    Create build strategy and exec strategy.
    Args:
    Returns:
        build_strategy: build strategy
        exec_strategy: exec strategy
    """
    build_strategy = paddle.static.BuildStrategy()
    exec_strategy = paddle.static.ExecutionStrategy()

    build_strategy.enable_addto = args.enable_addto
    exec_strategy.num_threads = 1
    exec_strategy.num_iteration_per_drop_scope = 10000
    build_strategy = paddle.static.BuildStrategy()
    build_strategy.enable_addto = args.enable_addto

    return build_strategy, exec_strategy
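(Editor's note, not part of the diff: the strategy pair returned by create_strategy() is consumed in two places below, by build_compiled_program() for the single-card CompiledProgram path and by dist_optimizer() for the multi-card fleet.DistributedStrategy path.)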
def build_compiled_program(main_program, loss):
    build_strategy, exec_strategy = create_strategy()

    main_program = paddle.static.CompiledProgram(
        main_program).with_data_parallel(
            loss_name=loss.name,
@@ -187,6 +201,33 @@ def build_compiled_program(main_program, loss):
    return main_program
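For context, a minimal single-card usage sketch (editor's illustration, not part of the diff; `place`, `loss`, and `feed_dict` are placeholders for objects created elsewhere in the script) of how the compiled program returned here is typically run with the static-graph Executor:

```python
# Placeholders: `place`, `loss`, and `feed_dict` come from the surrounding training script.
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())

# Single-card path: wrap the default main program once, then feed batches to it.
compiled_prog = build_compiled_program(paddle.static.default_main_program(), loss)
loss_value, = exe.run(compiled_prog, feed=feed_dict, fetch_list=[loss])
```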
def dist_optimizer(args, optimizer):
    """
    Create a distributed optimizer based on a normal optimizer
    Args:
        args:
        optimizer: a normal optimizer
    Returns:
        optimizer: a distributed optimizer
    """
    build_strategy, exec_strategy = create_strategy()

    dist_strategy = fleet.DistributedStrategy()
    dist_strategy.execution_strategy = exec_strategy
    dist_strategy.build_strategy = build_strategy

    dist_strategy.fuse_grad_size_in_MB = 16
    if args.use_amp:
        dist_strategy.amp = True
        dist_strategy.amp_configs = {
            'custom_white_list': ['softmax', 'layer_norm', 'gelu'],
            'init_loss_scaling': args.scale_loss,
        }

    optimizer = fleet.distributed_optimizer(optimizer, strategy=dist_strategy)
    return optimizer
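A minimal sketch of how dist_optimizer() is meant to be used under fleet collective training (editor's illustration; it assumes the functions above are in scope, and the Args stand-in and AdamW hyperparameters are assumptions, not values taken from this benchmark):

```python
import paddle
from paddle.distributed import fleet

paddle.enable_static()
fleet.init(is_collective=True)

class Args:
    # Hypothetical stand-in for the benchmark's parsed command-line arguments.
    use_amp = True
    scale_loss = 32768.0
    enable_addto = False

args = Args()

# A normal static-graph optimizer ...
optimizer = paddle.optimizer.AdamW(
    learning_rate=1e-4, epsilon=1e-8, weight_decay=0.01)

if fleet.worker_num() > 1:
    # ... becomes a fleet distributed optimizer, with build/exec strategies and
    # AMP supplied through the DistributedStrategy built in dist_optimizer().
    optimizer = dist_optimizer(args, optimizer)

# optimizer.minimize(loss) then behaves the same in both branches.
```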
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
@@ -208,9 +249,12 @@ def do_train(args):
    place = paddle.set_device(args.select_device)
    fleet.init(is_collective=True)

    worker_num = fleet.worker_num()
    worker_index = fleet.worker_index()

    # Create the random seed for the worker
    set_seed(args.seed)
    worker_init = WorkerInitObj(args.seed + fleet.worker_index())
    worker_init = WorkerInitObj(args.seed + worker_index)

    # Define the input data in the static mode
    main_program = paddle.static.default_main_program()
@@ -260,7 +304,7 @@ def do_train(args):
            p.name for n, p in model.named_parameters()
            if not any(nd in n for nd in ["bias", "norm"])
        ])
    if args.use_amp:
    if worker_num == 1 and args.use_amp:
        amp_list = paddle.fluid.contrib.mixed_precision.AutoMixedPrecisionLists(
            custom_white_list=['softmax', 'layer_norm', 'gelu'])
        optimizer = paddle.fluid.contrib.mixed_precision.decorate(
@@ -268,9 +312,10 @@ def do_train(args):
            amp_list,
            init_loss_scaling=args.scale_loss,
            use_dynamic_loss_scaling=True)
    # Use the fleet api to compile the distributed optimizer
    strategy = fleet.DistributedStrategy()
    optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
    if worker_num > 1:
        # Use the fleet api to compile the distributed optimizer
        optimizer = dist_optimizer(args, optimizer)

    optimizer.minimize(loss)

    # Define the Executor for running the static model
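(Editor's note, not part of the diff: with this change the paddle.fluid.contrib AMP decoration above runs only when worker_num == 1; in the multi-card case AMP is instead enabled through dist_strategy.amp and amp_configs inside dist_optimizer, presumably so the optimizer is not wrapped for AMP twice.)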
@@ -281,14 +326,14 @@ def do_train(args):
    # Use the state dict to update the parameter
    reset_state_dict = reset_program_state_dict(model, state_dict)
    paddle.static.set_program_state(main_program, reset_state_dict)
    # Construct the compiled program
    main_program = build_compiled_program(main_program, loss)
    if worker_num == 1:
        # Construct the compiled program
        main_program = build_compiled_program(main_program, loss)

    pool = ThreadPoolExecutor(1)
    global_step = 0
    tic_train = time.time()
    worker_num = fleet.worker_num()
    worker_index = fleet.worker_index()
    epoch = 0
    while True:
        files = [
......
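Usage note (editor's, not part of the diff): running the benchmark script directly keeps fleet.worker_num() == 1, so the CompiledProgram and contrib-AMP path above is taken; launching it with Paddle's distributed launcher, e.g. `python -m paddle.distributed.launch --gpus "0,1,2,3" <script>.py ...`, makes fleet.worker_num() > 1 and switches to the fleet.distributed_optimizer path. The launcher invocation and flags here are illustrative, not taken from this PR.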