Commit 8aac9839 authored by Zeyu Chen

Merge branch 'develop' of https://github.com/PaddlePaddle/models into develop

@@ -13,6 +13,8 @@
# limitations under the License.
import argparse
import collections
import itertools
import os
import random
import time
@@ -21,6 +23,7 @@ from functools import partial
from concurrent.futures import ThreadPoolExecutor
import numpy as np
import distutils.util
import paddle
import paddle.distributed.fleet as fleet
@@ -117,6 +120,22 @@ def parse_args():
help="Save checkpoint every X updates steps.")
parser.add_argument(
"--seed", type=int, default=42, help="Random seed for initialization")
parser.add_argument(
"--use_amp",
type=distutils.util.strtobool,
default=False,
help="Enable mixed precision training.")
parser.add_argument(
"--enable_addto",
type=distutils.util.strtobool,
default=False,
help="Whether to enable the addto strategy for gradient accumulation or not. This is only used for AMP training."
)
parser.add_argument(
"--scale_loss",
type=float,
default=1.0,
help="The value of scale_loss for fp16.")
args = parser.parse_args()
return args
@@ -149,6 +168,26 @@ def reset_program_state_dict(model, state_dict):
return new_state_dict
def build_compiled_program(main_program, loss):
exec_strategy = paddle.static.ExecutionStrategy()
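# Executor settings: ops run single-threaded, and temporary scopes are dropped only every 10000 iterations to reduce cleanup overhead.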
exec_strategy.num_threads = 1
exec_strategy.num_iteration_per_drop_scope = 10000
build_strategy = paddle.static.BuildStrategy()
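# The addto (in-place gradient accumulation) build strategy is controlled by the --enable_addto flag above, which is intended for AMP training only.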
build_strategy.enable_addto = args.enable_addto
main_program = paddle.static.CompiledProgram(
main_program).with_data_parallel(
loss_name=loss.name,
exec_strategy=exec_strategy,
build_strategy=build_strategy)
return main_program
def set_seed(seed):
random.seed(seed)
np.random.seed(seed)
paddle.seed(seed)
class WorkerInitObj(object):
def __init__(self, seed):
self.seed = seed
@@ -158,12 +197,6 @@ class WorkerInitObj(object):
random.seed(self.seed + id)
def set_seed(seed):
random.seed(seed)
np.random.seed(seed)
paddle.seed(seed)
def do_train(args):
# Initialize the paddle and paddle fleet execution environment
paddle.enable_static()
@@ -175,6 +208,8 @@ def do_train(args):
worker_init = WorkerInitObj(args.seed + fleet.worker_index())
# Define the input data in the static mode
main_program = paddle.static.default_main_program()
startup_program = paddle.static.default_startup_program()
data_holders = create_data_holder(args)
[
@@ -186,9 +221,10 @@
args.model_type = args.model_type.lower()
model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
model = BertForPretraining(
BertModel(**model_class.pretrained_init_configuration[
args.model_name_or_path]))
config = model_class.pretrained_init_configuration[args.model_name_or_path]
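# Pad the vocabulary size up to a multiple of 8; this is commonly done so fp16/Tensor Core kernels stay efficient when AMP is enabled.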
if config["vocab_size"] % 8 != 0:
config["vocab_size"] += 8 - (config["vocab_size"] % 8)
model = BertForPretraining(BertModel(**config))
criterion = BertPretrainingCriterion(model.bert.config["vocab_size"])
prediction_scores, seq_relationship_score = model(
input_ids=input_ids,
@@ -219,7 +255,14 @@ def do_train(args):
p.name for n, p in model.named_parameters()
if not any(nd in n for nd in ["bias", "norm"])
])
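# Wrap the optimizer for mixed precision: softmax is added to the fp16 white list, and the loss is scaled (dynamically, starting from --scale_loss) to avoid fp16 gradient underflow.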
if args.use_amp:
amp_list = paddle.fluid.contrib.mixed_precision.AutoMixedPrecisionLists(
custom_white_list=['softmax'])
optimizer = paddle.fluid.contrib.mixed_precision.decorate(
optimizer,
amp_list,
init_loss_scaling=args.scale_loss,
use_dynamic_loss_scaling=True)
# Use the fleet api to compile the distributed optimizer
strategy = fleet.DistributedStrategy()
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
@@ -227,13 +270,14 @@
# Define the Executor for running the static model
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
exe.run(startup_program)
state_dict = model.state_dict()
# Use the state dict to update the parameter
reset_state_dict = reset_program_state_dict(model, state_dict)
paddle.static.set_program_state(paddle.static.default_main_program(),
reset_state_dict)
paddle.static.set_program_state(main_program, reset_state_dict)
# Construct the compiled program
main_program = build_compiled_program(main_program, loss)
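# A single-worker thread pool, presumably used by the training loop to prefetch the next pretraining data file asynchronously.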
pool = ThreadPoolExecutor(1)
global_step = 0
@@ -269,9 +313,9 @@ def do_train(args):
for step, batch in enumerate(train_data_loader):
global_step += 1
loss_return = exe.run(paddle.static.default_main_program(),\
feed=batch,
fetch_list=[loss])
loss_return = exe.run(main_program,
feed=batch,
fetch_list=[loss])
# In the 2.0 API, this must be called to update the learning rate
lr_scheduler.step()
if global_step % args.logging_steps == 0:
......
@@ -132,16 +132,11 @@ def parse_args():
type=float,
default=1.0,
help="The value of scale_loss for fp16.")
parser.add_argument(
"--use_dynamic_loss_scaling",
type=distutils.util.strtobool,
default=True,
help="Whether to use dynamic loss scaling.")
args = parser.parse_args()
return args
def construct_compiled_program(main_program, loss):
def build_compiled_program(main_program, loss):
exec_strategy = paddle.static.ExecutionStrategy()
exec_strategy.num_threads = 1
exec_strategy.num_iteration_per_drop_scope = 10000
@@ -238,7 +233,7 @@ def do_train(args):
optimizer,
amp_list,
init_loss_scaling=args.scale_loss,
use_dynamic_loss_scaling=args.use_dynamic_loss_scaling)
use_dynamic_loss_scaling=True)
optimizer.minimize(loss)
# Define the Executor for running the static model
@@ -250,7 +245,7 @@ def do_train(args):
reset_state_dict = reset_program_state_dict(model, state_dict)
paddle.static.set_program_state(main_program, reset_state_dict)
# Construct the compiled program
main_program = construct_compiled_program(main_program, loss)
main_program = build_compiled_program(main_program, loss)
global_step = 0
tic_train = time.time()
epoch = 0
......