Unverified commit c265768d, authored by wawltor and committed by GitHub

Add AMP support for fleet training (#4996)

* add support for automatic mixed precision (AMP) training to the fleet pretraining script
* remove the --use_dynamic_loss_scaling attribute from the existing AMP options
Parent 8c076ce0
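For orientation before the diff: a minimal, hedged sketch of the AMP wiring this commit introduces. The white list, loss-scaling arguments, and the decorate-then-distribute ordering mirror the diff below; the toy model, learning rate, and data shapes are illustrative placeholders, not part of the commit.

```python
import paddle
import paddle.distributed.fleet as fleet

paddle.enable_static()
fleet.init(is_collective=True)

# Toy static-graph network standing in for BertForPretraining.
x = paddle.static.data(name="x", shape=[None, 16], dtype="float32")
loss = paddle.mean(paddle.static.nn.fc(x, size=1))

optimizer = paddle.optimizer.Adam(learning_rate=1e-4)

# Decorate the base optimizer for AMP *before* handing it to fleet:
# white-listed ops run in FP16, and dynamic loss scaling rescales
# gradients to avoid FP16 underflow.
amp_list = paddle.fluid.contrib.mixed_precision.AutoMixedPrecisionLists(
    custom_white_list=["softmax"])
optimizer = paddle.fluid.contrib.mixed_precision.decorate(
    optimizer,
    amp_list,
    init_loss_scaling=1.0,  # exposed as --scale_loss in the script
    use_dynamic_loss_scaling=True)

# Then wrap with the fleet distributed optimizer, as the script does.
strategy = fleet.DistributedStrategy()
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
optimizer.minimize(loss)
```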
@@ -13,6 +13,8 @@
 # limitations under the License.
 import argparse
+import collections
+import itertools
 import os
 import random
 import time
@@ -21,6 +23,7 @@ from functools import partial
 from concurrent.futures import ThreadPoolExecutor

 import numpy as np
+import distutils.util
 import paddle
 import paddle.distributed.fleet as fleet
@@ -117,6 +120,22 @@ def parse_args():
         help="Save checkpoint every X updates steps.")
     parser.add_argument(
         "--seed", type=int, default=42, help="Random seed for initialization")
+    parser.add_argument(
+        "--use_amp",
+        type=distutils.util.strtobool,
+        default=False,
+        help="Enable mixed precision training.")
+    parser.add_argument(
+        "--enable_addto",
+        type=distutils.util.strtobool,
+        default=False,
+        help="Whether to enable the addto strategy for gradient accumulation or not. This is only used for AMP training."
+    )
+    parser.add_argument(
+        "--scale_loss",
+        type=float,
+        default=1.0,
+        help="The value of scale_loss for fp16.")
     args = parser.parse_args()
     return args
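A quick aside on why these flags use distutils.util.strtobool rather than bool as the argparse type: argparse calls the type on the raw string, and bool("False") is True for any non-empty string. A tiny self-contained illustration (flag name borrowed from the diff):

```python
import argparse
import distutils.util

parser = argparse.ArgumentParser()
# type=bool would turn the string "False" into True (non-empty string);
# strtobool maps "y/yes/true/1" to 1 and "n/no/false/0" to 0.
parser.add_argument("--use_amp", type=distutils.util.strtobool, default=False)

print(parser.parse_args(["--use_amp", "False"]).use_amp)  # -> 0
print(parser.parse_args(["--use_amp", "1"]).use_amp)      # -> 1
```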
@@ -149,6 +168,26 @@ def reset_program_state_dict(model, state_dict):
     return new_state_dict


+def build_compiled_program(main_program, loss):
+    exec_strategy = paddle.static.ExecutionStrategy()
+    exec_strategy.num_threads = 1
+    exec_strategy.num_iteration_per_drop_scope = 10000
+    build_strategy = paddle.static.BuildStrategy()
+    build_strategy.enable_addto = args.enable_addto
+    main_program = paddle.static.CompiledProgram(
+        main_program).with_data_parallel(
+            loss_name=loss.name,
+            exec_strategy=exec_strategy,
+            build_strategy=build_strategy)
+    return main_program
+
+
+def set_seed(seed):
+    random.seed(seed)
+    np.random.seed(seed)
+    paddle.seed(seed)
+
+
 class WorkerInitObj(object):
     def __init__(self, seed):
         self.seed = seed
@@ -158,12 +197,6 @@ class WorkerInitObj(object):
         random.seed(self.seed + id)

-
-def set_seed(seed):
-    random.seed(seed)
-    np.random.seed(seed)
-    paddle.seed(seed)
-
 def do_train(args):
     # Initialize the paddle and paddle fleet execute environment
     paddle.enable_static()
@@ -175,6 +208,8 @@ def do_train(args):
     worker_init = WorkerInitObj(args.seed + fleet.worker_index())

     # Define the input data in the static mode
+    main_program = paddle.static.default_main_program()
+    startup_program = paddle.static.default_startup_program()
     data_holders = create_data_holder(args)
     [
@@ -186,9 +221,10 @@
     args.model_type = args.model_type.lower()
     model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
     tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
-    model = BertForPretraining(
-        BertModel(**model_class.pretrained_init_configuration[
-            args.model_name_or_path]))
+    config = model_class.pretrained_init_configuration[args.model_name_or_path]
+    if config["vocab_size"] % 8 != 0:
+        config["vocab_size"] += 8 - (config["vocab_size"] % 8)
+    model = BertForPretraining(BertModel(**config))
     criterion = BertPretrainingCriterion(model.bert.config["vocab_size"])
     prediction_scores, seq_relationship_score = model(
         input_ids=input_ids,
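The vocab_size rounding introduced above is an AMP-era performance detail: padding the vocabulary to a multiple of 8 keeps the FP16 embedding and output-projection matrix shapes aligned with NVIDIA Tensor Core requirements. A standalone sketch of the same arithmetic (the helper name and the BERT-base example are ours, not the commit's):

```python
def pad_vocab_size(vocab_size, divisor=8):
    """Round vocab_size up to the next multiple of divisor,
    matching the padding the diff applies before building the model."""
    if vocab_size % divisor != 0:
        vocab_size += divisor - (vocab_size % divisor)
    return vocab_size

# bert-base-uncased has a 30522-token vocabulary; padded it becomes 30528.
assert pad_vocab_size(30522) == 30528
assert pad_vocab_size(30528) == 30528  # already aligned, unchanged
```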
@@ -219,7 +255,14 @@ def do_train(args):
             p.name for n, p in model.named_parameters()
             if not any(nd in n for nd in ["bias", "norm"])
         ])
+    if args.use_amp:
+        amp_list = paddle.fluid.contrib.mixed_precision.AutoMixedPrecisionLists(
+            custom_white_list=['softmax'])
+        optimizer = paddle.fluid.contrib.mixed_precision.decorate(
+            optimizer,
+            amp_list,
+            init_loss_scaling=args.scale_loss,
+            use_dynamic_loss_scaling=True)

     # Use the fleet api to compile the distributed optimizer
     strategy = fleet.DistributedStrategy()
     optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
@@ -227,13 +270,14 @@ def do_train(args):
     # Define the Executor for running the static model
     exe = paddle.static.Executor(place)
-    exe.run(paddle.static.default_startup_program())
+    exe.run(startup_program)

     state_dict = model.state_dict()
     # Use the state dict to update the parameter
     reset_state_dict = reset_program_state_dict(model, state_dict)
-    paddle.static.set_program_state(paddle.static.default_main_program(),
-                                    reset_state_dict)
+    paddle.static.set_program_state(main_program, reset_state_dict)
+    # Construct the compiled program
+    main_program = build_compiled_program(main_program, loss)

     pool = ThreadPoolExecutor(1)
     global_step = 0
@@ -269,7 +313,7 @@ def do_train(args):
         for step, batch in enumerate(train_data_loader):
             global_step += 1
-            loss_return = exe.run(paddle.static.default_main_program(),\
+            loss_return = exe.run(main_program,
                                   feed=batch,
                                   fetch_list=[loss])
             # In the new 2.0 api, must call this function to change the learning_rate
The commit also simplifies the existing pretraining script: the --use_dynamic_loss_scaling flag is removed (dynamic loss scaling is now always enabled for AMP) and construct_compiled_program is renamed to build_compiled_program:
@@ -132,16 +132,11 @@ def parse_args():
         type=float,
         default=1.0,
         help="The value of scale_loss for fp16.")
-    parser.add_argument(
-        "--use_dynamic_loss_scaling",
-        type=distutils.util.strtobool,
-        default=True,
-        help="Whether to use dynamic loss scaling.")
     args = parser.parse_args()
     return args


-def construct_compiled_program(main_program, loss):
+def build_compiled_program(main_program, loss):
     exec_strategy = paddle.static.ExecutionStrategy()
     exec_strategy.num_threads = 1
     exec_strategy.num_iteration_per_drop_scope = 10000
@@ -238,7 +233,7 @@ def do_train(args):
             optimizer,
             amp_list,
             init_loss_scaling=args.scale_loss,
-            use_dynamic_loss_scaling=args.use_dynamic_loss_scaling)
+            use_dynamic_loss_scaling=True)
     optimizer.minimize(loss)

     # Define the Executor for running the static model
@@ -250,7 +245,7 @@ def do_train(args):
     reset_state_dict = reset_program_state_dict(model, state_dict)
     paddle.static.set_program_state(main_program, reset_state_dict)
     # Construct the compiled program
-    main_program = construct_compiled_program(main_program, loss)
+    main_program = build_compiled_program(main_program, loss)
     global_step = 0
     tic_train = time.time()
     epoch = 0