Unverified commit 711aa19b, authored by W WangXi, committed by GitHub

fix bert benchmark fleet compiled (#5117)

Parent d7009805
@@ -172,13 +172,27 @@ def reset_program_state_dict(model, state_dict):
             loc=0.0, scale=scale, size=p.shape).astype(dtype_str)
     return new_state_dict
 
 
-def build_compiled_program(main_program, loss):
+def create_strategy():
+    """
+    Create build strategy and exec strategy.
+    Args:
+
+    Returns:
+        build_strategy: build strategy
+        exec_strategy: exec strategy
+    """
+    build_strategy = paddle.static.BuildStrategy()
     exec_strategy = paddle.static.ExecutionStrategy()
+
+    build_strategy.enable_addto = args.enable_addto
     exec_strategy.num_threads = 1
     exec_strategy.num_iteration_per_drop_scope = 10000
-    build_strategy = paddle.static.BuildStrategy()
-    build_strategy.enable_addto = args.enable_addto
+    return build_strategy, exec_strategy
+
+
+def build_compiled_program(main_program, loss):
+    build_strategy, exec_strategy = create_strategy()
     main_program = paddle.static.CompiledProgram(
         main_program).with_data_parallel(
             loss_name=loss.name,
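(Review note) For orientation, a minimal single-worker sketch of how the new create_strategy helper feeds CompiledProgram.with_data_parallel. This is an illustration rather than the script itself: the toy network and the ENABLE_ADDTO constant stand in for the real model and args.enable_addto, and it assumes the Paddle 2.x static-graph API.

import paddle

paddle.enable_static()

ENABLE_ADDTO = False  # stand-in for args.enable_addto


def create_strategy():
    # Shared build/exec strategies, as introduced in this commit.
    build_strategy = paddle.static.BuildStrategy()
    exec_strategy = paddle.static.ExecutionStrategy()
    build_strategy.enable_addto = ENABLE_ADDTO
    exec_strategy.num_threads = 1
    exec_strategy.num_iteration_per_drop_scope = 10000
    return build_strategy, exec_strategy


# Toy program so the compile step has a loss to parallelize on.
x = paddle.static.data(name="x", shape=[None, 8], dtype="float32")
loss = paddle.mean(paddle.static.nn.fc(x, size=1))

build_strategy, exec_strategy = create_strategy()
compiled_program = paddle.static.CompiledProgram(
    paddle.static.default_main_program()).with_data_parallel(
        loss_name=loss.name,
        build_strategy=build_strategy,
        exec_strategy=exec_strategy)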
@@ -187,6 +201,33 @@ def build_compiled_program(main_program, loss):
     return main_program
 
 
+def dist_optimizer(args, optimizer):
+    """
+    Create a distributed optimizer based on a normal optimizer
+    Args:
+        args:
+        optimizer: a normal optimizer
+    Returns:
+        optimizer: a distributed optimizer
+    """
+    build_strategy, exec_strategy = create_strategy()
+
+    dist_strategy = fleet.DistributedStrategy()
+    dist_strategy.execution_strategy = exec_strategy
+    dist_strategy.build_strategy = build_strategy
+
+    dist_strategy.fuse_grad_size_in_MB = 16
+    if args.use_amp:
+        dist_strategy.amp = True
+        dist_strategy.amp_configs = {
+            'custom_white_list': ['softmax', 'layer_norm', 'gelu'],
+            'init_loss_scaling': args.scale_loss,
+        }
+
+    optimizer = fleet.distributed_optimizer(optimizer, strategy=dist_strategy)
+    return optimizer
+
+
 def set_seed(seed):
     random.seed(seed)
     np.random.seed(seed)
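(Review note) The key change is that, in multi-worker runs, AMP and gradient fusion are now expressed through fleet's DistributedStrategy instead of the fluid mixed_precision decorator. Below is a condensed, self-contained sketch of that wrapping pattern; the AdamW optimizer, the toy loss and the 32768.0 loss-scaling value are placeholders for the script's real ones, the real dist_optimizer also attaches the build/exec strategies from create_strategy, and the snippet is intended to run under python -m paddle.distributed.launch.

import paddle
import paddle.distributed.fleet as fleet

paddle.enable_static()
fleet.init(is_collective=True)

# Toy static-graph loss so minimize() has something to work on.
x = paddle.static.data(name="x", shape=[None, 8], dtype="float32")
loss = paddle.mean(paddle.static.nn.fc(x, size=1))

optimizer = paddle.optimizer.AdamW(learning_rate=1e-4)  # placeholder optimizer

dist_strategy = fleet.DistributedStrategy()
dist_strategy.fuse_grad_size_in_MB = 16
dist_strategy.amp = True
dist_strategy.amp_configs = {
    'custom_white_list': ['softmax', 'layer_norm', 'gelu'],
    'init_loss_scaling': 32768.0,  # stand-in for args.scale_loss
}

# fleet wraps the plain optimizer; its minimize() inserts the distributed
# and AMP passes described by the strategy.
optimizer = fleet.distributed_optimizer(optimizer, strategy=dist_strategy)
optimizer.minimize(loss)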
@@ -208,9 +249,12 @@ def do_train(args):
     place = paddle.set_device(args.select_device)
     fleet.init(is_collective=True)
+
+    worker_num = fleet.worker_num()
+    worker_index = fleet.worker_index()
 
     # Create the random seed for the worker
     set_seed(args.seed)
-    worker_init = WorkerInitObj(args.seed + fleet.worker_index())
+    worker_init = WorkerInitObj(args.seed + worker_index)
 
     # Define the input data in the static mode
     main_program = paddle.static.default_main_program()
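(Review note) worker_num and worker_index are now read from fleet once and reused, and the per-worker seed is just the base seed offset by the worker index. A tiny illustration of that seeding pattern with made-up values:

import random

import numpy as np


def seed_worker(base_seed, worker_index):
    # Offset the base seed by the worker index so every worker draws an
    # independent but reproducible random stream, as in the diff above.
    worker_seed = base_seed + worker_index
    random.seed(worker_seed)
    np.random.seed(worker_seed)
    return worker_seed


print(seed_worker(base_seed=42, worker_index=0))  # 42
print(seed_worker(base_seed=42, worker_index=3))  # 45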
@@ -260,7 +304,7 @@ def do_train(args):
             p.name for n, p in model.named_parameters()
             if not any(nd in n for nd in ["bias", "norm"])
         ])
-    if args.use_amp:
+    if worker_num == 1 and args.use_amp:
         amp_list = paddle.fluid.contrib.mixed_precision.AutoMixedPrecisionLists(
             custom_white_list=['softmax', 'layer_norm', 'gelu'])
         optimizer = paddle.fluid.contrib.mixed_precision.decorate(
@@ -268,9 +312,10 @@ def do_train(args):
             amp_list,
             init_loss_scaling=args.scale_loss,
             use_dynamic_loss_scaling=True)
-    # Use the fleet api to compile the distributed optimizer
-    strategy = fleet.DistributedStrategy()
-    optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
+
+    if worker_num > 1:
+        # Use the fleet api to compile the distributed optimizer
+        optimizer = dist_optimizer(args, optimizer)
     optimizer.minimize(loss)
 
     # Define the Executor for running the static model
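(Review note) Read across the two hunks above, optimizer setup now dispatches on the worker count: a single worker keeps the fluid mixed_precision path, while multiple workers go through the new dist_optimizer. The sketch below restates that combined control flow in one place; args and dist_optimizer refer to the script's own objects, so this is a summary of the diff rather than a standalone program.

import paddle


def wrap_optimizer(args, optimizer, loss, worker_num):
    # Single worker: keep the original fluid mixed-precision decoration.
    if worker_num == 1 and args.use_amp:
        amp_list = paddle.fluid.contrib.mixed_precision.AutoMixedPrecisionLists(
            custom_white_list=['softmax', 'layer_norm', 'gelu'])
        optimizer = paddle.fluid.contrib.mixed_precision.decorate(
            optimizer,
            amp_list,
            init_loss_scaling=args.scale_loss,
            use_dynamic_loss_scaling=True)
    # Multiple workers: AMP and fusion live in the DistributedStrategy built
    # inside dist_optimizer, and fleet wraps the plain optimizer.
    if worker_num > 1:
        optimizer = dist_optimizer(args, optimizer)
    optimizer.minimize(loss)
    return optimizer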
@@ -281,14 +326,14 @@ def do_train(args):
     # Use the state dict to update the parameter
     reset_state_dict = reset_program_state_dict(model, state_dict)
     paddle.static.set_program_state(main_program, reset_state_dict)
-    # Construct the compiled program
-    main_program = build_compiled_program(main_program, loss)
+
+    if worker_num == 1:
+        # Construct the compiled program
+        main_program = build_compiled_program(main_program, loss)
 
     pool = ThreadPoolExecutor(1)
     global_step = 0
     tic_train = time.time()
-    worker_num = fleet.worker_num()
-    worker_index = fleet.worker_index()
     epoch = 0
     while True:
         files = [