未验证 提交 4059a44d 编写于 作者: Z Zhang Ting 提交者: GitHub

support AMP training (#5067)

* support AMP training
上级 0b8e80b2
...@@ -96,9 +96,10 @@ dropout: 0.1 ...@@ -96,9 +96,10 @@ dropout: 0.1
# Vocabularies in source and target should be same for weight sharing. # Vocabularies in source and target should be same for weight sharing.
weight_sharing: True weight_sharing: True
# Use amp or not # Mixed precision training
use_amp: False use_amp: False
scale_loss: 1.0 use_pure_fp16: False
scale_loss: 128.0
# Whether to use multi-card/multi-node distributed training. # Whether to use multi-card/multi-node distributed training.
# Only works for static graph for now. # Only works for static graph for now.
......
...@@ -96,9 +96,10 @@ dropout: 0.1 ...@@ -96,9 +96,10 @@ dropout: 0.1
# Vocabularies in source and target should be same for weight sharing. # Vocabularies in source and target should be same for weight sharing.
weight_sharing: True weight_sharing: True
# Use amp or not # Mixed precision training
use_amp: False use_amp: False
scale_loss: 1.0 use_pure_fp16: False
scale_loss: 128.0
# Whether to use multi-card/multi-node distributed training. # Whether to use multi-card/multi-node distributed training.
# Only works for static graph for now. # Only works for static graph for now.
......
...@@ -20,6 +20,18 @@ FORMAT = '%(asctime)s-%(levelname)s: %(message)s' ...@@ -20,6 +20,18 @@ FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT) logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def cast_parameters_to_fp32(place, program, scope=None):
all_parameters = []
for block in program.blocks:
all_parameters.extend(block.all_parameters())
var_scope = scope if scope else paddle.static.global_scope()
for param in all_parameters:
tensor = var_scope.find_var(param.name).get_tensor()
if 'fp16' in str(tensor._dtype()).lower() and \
'fp32' in str(param.dtype).lower():
data = np.array(tensor)
tensor.set(np.float32(data), place)
def parse_args(): def parse_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
...@@ -93,6 +105,10 @@ def do_predict(args): ...@@ -93,6 +105,10 @@ def do_predict(args):
os.path.join(args.init_from_params, "transformer"), exe) os.path.join(args.init_from_params, "transformer"), exe)
print("finish initing model from params from %s" % (args.init_from_params)) print("finish initing model from params from %s" % (args.init_from_params))
# cast weights from fp16 to fp32 after loading
if args.use_pure_fp16:
cast_parameters_to_fp32(place, test_program)
f = open(args.output_file, "w") f = open(args.output_file, "w")
for data in test_loader: for data in test_loader:
finished_sequence, = exe.run(test_program, finished_sequence, = exe.run(test_program,
......
...@@ -114,6 +114,17 @@ def do_train(args): ...@@ -114,6 +114,17 @@ def do_train(args):
optimizer = fleet.distributed_optimizer( optimizer = fleet.distributed_optimizer(
optimizer, strategy=dist_strategy) optimizer, strategy=dist_strategy)
else:
if args.use_amp:
amp_list = paddle.static.amp.AutoMixedPrecisionLists(
custom_white_list=['softmax', 'layer_norm'],
custom_black_list=['lookup_table_v2'])
optimizer = paddle.static.amp.decorate(
optimizer,
amp_list,
init_loss_scaling=args.scale_loss,
use_dynamic_loss_scaling=True,
use_pure_fp16=args.use_pure_fp16)
optimizer.minimize(avg_cost) optimizer.minimize(avg_cost)
if args.is_distributed: if args.is_distributed:
...@@ -130,6 +141,9 @@ def do_train(args): ...@@ -130,6 +141,9 @@ def do_train(args):
exec_strategy=exec_strategy) exec_strategy=exec_strategy)
exe.run(startup_program) exe.run(startup_program)
if not args.is_distributed and args.use_amp:
optimizer.amp_init(places[0])
# the best cross-entropy value with label smoothing # the best cross-entropy value with label smoothing
loss_normalizer = -( loss_normalizer = -(
(1. - args.label_smooth_eps) * np.log( (1. - args.label_smooth_eps) * np.log(
......
...@@ -287,7 +287,7 @@ class TransformerModel(nn.Layer): ...@@ -287,7 +287,7 @@ class TransformerModel(nn.Layer):
trg_pos = paddle.cast( trg_pos = paddle.cast(
trg_word != self.bos_id, dtype="int64") * paddle.arange( trg_word != self.bos_id, dtype="int64") * paddle.arange(
start=0, end=trg_max_len) start=0, end=trg_max_len)
with paddle.static.amp.fp16_guard():
src_emb = self.src_word_embedding(src_word) src_emb = self.src_word_embedding(src_word)
src_pos_emb = self.src_pos_embedding(src_pos) src_pos_emb = self.src_pos_embedding(src_pos)
src_emb = src_emb + src_pos_emb src_emb = src_emb + src_pos_emb
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册