diff --git a/PaddleNLP/benchmark/bert/run_pretrain.py b/PaddleNLP/benchmark/bert/run_pretrain.py index ce5e2d1eb86cd48f025d7780d1c7ef3b32cef50d..c81c4bbc15fc8dacc1badbfa84788dc2451d9e4e 100644 --- a/PaddleNLP/benchmark/bert/run_pretrain.py +++ b/PaddleNLP/benchmark/bert/run_pretrain.py @@ -172,13 +172,27 @@ def reset_program_state_dict(model, state_dict): loc=0.0, scale=scale, size=p.shape).astype(dtype_str) return new_state_dict - -def build_compiled_program(main_program, loss): +def create_strategy(): + """ + Create build strategy and exec strategy. + Args: + + Returns: + build_strategy: build strategy + exec_strategy: exec strategy + """ + build_strategy = paddle.static.BuildStrategy() exec_strategy = paddle.static.ExecutionStrategy() + + build_strategy.enable_addto = args.enable_addto + exec_strategy.num_threads = 1 exec_strategy.num_iteration_per_drop_scope = 10000 - build_strategy = paddle.static.BuildStrategy() - build_strategy.enable_addto = args.enable_addto + return build_strategy, exec_strategy + + +def build_compiled_program(main_program, loss): + build_strategy, exec_strategy = create_strategy() main_program = paddle.static.CompiledProgram( main_program).with_data_parallel( loss_name=loss.name, @@ -187,6 +201,33 @@ def build_compiled_program(main_program, loss): return main_program +def dist_optimizer(args, optimizer): + """ + Create a distributed optimizer based on a normal optimizer + Args: + args: + optimizer: a normal optimizer + Returns: + optimizer: a distributed optimizer + """ + build_strategy, exec_strategy = create_strategy() + + dist_strategy = fleet.DistributedStrategy() + dist_strategy.execution_strategy = exec_strategy + dist_strategy.build_strategy = build_strategy + + dist_strategy.fuse_grad_size_in_MB = 16 + if args.use_amp: + dist_strategy.amp = True + dist_strategy.amp_configs = { + 'custom_white_list': ['softmax', 'layer_norm', 'gelu'], + 'init_loss_scaling': args.scale_loss, + } + + optimizer = fleet.distributed_optimizer(optimizer, strategy=dist_strategy) + return optimizer + + def set_seed(seed): random.seed(seed) np.random.seed(seed) @@ -208,9 +249,12 @@ def do_train(args): place = paddle.set_device(args.select_device) fleet.init(is_collective=True) + worker_num = fleet.worker_num() + worker_index = fleet.worker_index() + # Create the random seed for the worker set_seed(args.seed) - worker_init = WorkerInitObj(args.seed + fleet.worker_index()) + worker_init = WorkerInitObj(args.seed + worker_index) # Define the input data in the static mode main_program = paddle.static.default_main_program() @@ -260,7 +304,7 @@ def do_train(args): p.name for n, p in model.named_parameters() if not any(nd in n for nd in ["bias", "norm"]) ]) - if args.use_amp: + if worker_num == 1 and args.use_amp: amp_list = paddle.fluid.contrib.mixed_precision.AutoMixedPrecisionLists( custom_white_list=['softmax', 'layer_norm', 'gelu']) optimizer = paddle.fluid.contrib.mixed_precision.decorate( @@ -268,9 +312,10 @@ def do_train(args): amp_list, init_loss_scaling=args.scale_loss, use_dynamic_loss_scaling=True) - # Use the fleet api to compile the distributed optimizer - strategy = fleet.DistributedStrategy() - optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) + + if worker_num > 1: + # Use the fleet api to compile the distributed optimizer + optimizer = dist_optimizer(args, optimizer) optimizer.minimize(loss) # Define the Executor for running the static model @@ -281,14 +326,14 @@ def do_train(args): # Use the state dict to update the parameter reset_state_dict = reset_program_state_dict(model, state_dict) paddle.static.set_program_state(main_program, reset_state_dict) - # Construct the compiled program - main_program = build_compiled_program(main_program, loss) + + if worker_num == 1: + # Construct the compiled program + main_program = build_compiled_program(main_program, loss) pool = ThreadPoolExecutor(1) global_step = 0 tic_train = time.time() - worker_num = fleet.worker_num() - worker_index = fleet.worker_index() epoch = 0 while True: files = [