diff --git a/PaddleNLP/benchmark/transformer/configs/transformer.base.yaml b/PaddleNLP/benchmark/transformer/configs/transformer.base.yaml
index f216c8abe3888e38cf9c51d357ed6dd128cfc9b1..5f162be873002f3eaf205ce42cc990d19c198fd0 100644
--- a/PaddleNLP/benchmark/transformer/configs/transformer.base.yaml
+++ b/PaddleNLP/benchmark/transformer/configs/transformer.base.yaml
@@ -96,9 +96,10 @@ dropout: 0.1
 # Vocabularies in source and target should be same for weight sharing.
 weight_sharing: True
 
-# Use amp or not
+# Mixed precision training
 use_amp: False
-scale_loss: 1.0
+use_pure_fp16: False
+scale_loss: 128.0
 
 # Whether to use multi-card/multi-node distributed training.
 # Only works for static graph for now.
diff --git a/PaddleNLP/benchmark/transformer/configs/transformer.big.yaml b/PaddleNLP/benchmark/transformer/configs/transformer.big.yaml
index 4f4cbfdeb14c735da0b05bd57cd20ce13765cbff..cf08c76e28d9024dea5ee170b113be8c12f80bd4 100644
--- a/PaddleNLP/benchmark/transformer/configs/transformer.big.yaml
+++ b/PaddleNLP/benchmark/transformer/configs/transformer.big.yaml
@@ -96,9 +96,10 @@ dropout: 0.1
 # Vocabularies in source and target should be same for weight sharing.
 weight_sharing: True
 
-# Use amp or not
+# Mixed precision training
 use_amp: False
-scale_loss: 1.0
+use_pure_fp16: False
+scale_loss: 128.0
 
 # Whether to use multi-card/multi-node distributed training.
 # Only works for static graph for now.
diff --git a/PaddleNLP/benchmark/transformer/static/predict.py b/PaddleNLP/benchmark/transformer/static/predict.py
index 0ba42e6e02aae7f53ed054b6bedc7df2f0e015b7..245e690e2d0f118767a5c02605ad14c787383a67 100644
--- a/PaddleNLP/benchmark/transformer/static/predict.py
+++ b/PaddleNLP/benchmark/transformer/static/predict.py
@@ -20,6 +20,18 @@ FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
 logging.basicConfig(level=logging.INFO, format=FORMAT)
 logger = logging.getLogger(__name__)
 
+def cast_parameters_to_fp32(place, program, scope=None):
+    all_parameters = []
+    for block in program.blocks:
+        all_parameters.extend(block.all_parameters())
+
+    var_scope = scope if scope else paddle.static.global_scope()
+    for param in all_parameters:
+        tensor = var_scope.find_var(param.name).get_tensor()
+        if 'fp16' in str(tensor._dtype()).lower() and \
+                'fp32' in str(param.dtype).lower():
+            data = np.array(tensor)
+            tensor.set(np.float32(data), place)
 
 def parse_args():
     parser = argparse.ArgumentParser()
@@ -93,6 +105,10 @@ def do_predict(args):
                        os.path.join(args.init_from_params, "transformer"), exe)
     print("finish initing model from params from %s" % (args.init_from_params))
 
+    # cast weights from fp16 to fp32 after loading
+    if args.use_pure_fp16:
+        cast_parameters_to_fp32(place, test_program)
+
     f = open(args.output_file, "w")
     for data in test_loader:
         finished_sequence, = exe.run(test_program,
diff --git a/PaddleNLP/benchmark/transformer/static/train.py b/PaddleNLP/benchmark/transformer/static/train.py
index 23d5a332e43b1706091c2fefbedded3ccc00de82..45904f7de60052518b150943cc5820218e2abc6c 100644
--- a/PaddleNLP/benchmark/transformer/static/train.py
+++ b/PaddleNLP/benchmark/transformer/static/train.py
@@ -114,6 +114,17 @@ def do_train(args):
 
         optimizer = fleet.distributed_optimizer(
             optimizer, strategy=dist_strategy)
+    else:
+        if args.use_amp:
+            amp_list = paddle.static.amp.AutoMixedPrecisionLists(
+                custom_white_list=['softmax', 'layer_norm'],
+                custom_black_list=['lookup_table_v2'])
+            optimizer = paddle.static.amp.decorate(
+                optimizer,
+                amp_list,
+                init_loss_scaling=args.scale_loss,
+                use_dynamic_loss_scaling=True,
+                use_pure_fp16=args.use_pure_fp16)
     optimizer.minimize(avg_cost)
 
     if args.is_distributed:
@@ -130,6 +141,9 @@
                 exec_strategy=exec_strategy)
     exe.run(startup_program)
 
+    if not args.is_distributed and args.use_amp:
+        optimizer.amp_init(places[0])
+
     # the best cross-entropy value with label smoothing
     loss_normalizer = -(
         (1. - args.label_smooth_eps) * np.log(
diff --git a/PaddleNLP/paddlenlp/transformers/transformer/modeling.py b/PaddleNLP/paddlenlp/transformers/transformer/modeling.py
index 76d91b46c6989c822a096b143bf80c2560a8e07e..a3d25fffff6c16475c9237d76a171a4aaf0087bf 100644
--- a/PaddleNLP/paddlenlp/transformers/transformer/modeling.py
+++ b/PaddleNLP/paddlenlp/transformers/transformer/modeling.py
@@ -287,29 +287,29 @@ class TransformerModel(nn.Layer):
         trg_pos = paddle.cast(
             trg_word != self.bos_id, dtype="int64") * paddle.arange(
                 start=0, end=trg_max_len)
-
-        src_emb = self.src_word_embedding(src_word)
-        src_pos_emb = self.src_pos_embedding(src_pos)
-        src_emb = src_emb + src_pos_emb
-        enc_input = F.dropout(
-            src_emb, p=self.dropout,
-            training=self.training) if self.dropout else src_emb
-
-        trg_emb = self.trg_word_embedding(trg_word)
-        trg_pos_emb = self.trg_pos_embedding(trg_pos)
-        trg_emb = trg_emb + trg_pos_emb
-        dec_input = F.dropout(
-            trg_emb, p=self.dropout,
-            training=self.training) if self.dropout else trg_emb
-
-        dec_output = self.transformer(
-            enc_input,
-            dec_input,
-            src_mask=src_slf_attn_bias,
-            tgt_mask=trg_slf_attn_bias,
-            memory_mask=trg_src_attn_bias)
-
-        predict = self.linear(dec_output)
+        with paddle.static.amp.fp16_guard():
+            src_emb = self.src_word_embedding(src_word)
+            src_pos_emb = self.src_pos_embedding(src_pos)
+            src_emb = src_emb + src_pos_emb
+            enc_input = F.dropout(
+                src_emb, p=self.dropout,
+                training=self.training) if self.dropout else src_emb
+
+            trg_emb = self.trg_word_embedding(trg_word)
+            trg_pos_emb = self.trg_pos_embedding(trg_pos)
+            trg_emb = trg_emb + trg_pos_emb
+            dec_input = F.dropout(
+                trg_emb, p=self.dropout,
+                training=self.training) if self.dropout else trg_emb
+
+            dec_output = self.transformer(
+                enc_input,
+                dec_input,
+                src_mask=src_slf_attn_bias,
+                tgt_mask=trg_slf_attn_bias,
+                memory_mask=trg_src_attn_bias)
+
+            predict = self.linear(dec_output)
 
         return predict
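For reference, below is a minimal sketch (not part of the patch) of the static-graph pure-FP16 flow these hunks wire together: build the FP16 region under `fp16_guard`, decorate the optimizer with `use_pure_fp16`, run the startup program, then call `optimizer.amp_init` to cast parameters before the first step, mirroring `train.py` and `modeling.py` above. The toy `fc` network, the Adam optimizer, the input shape, and the random feed data are illustrative assumptions; a CUDA place is assumed since pure FP16 targets GPU kernels.

```python
# Illustrative sketch of the pure-FP16 pattern used by the patch; not the benchmark script itself.
import numpy as np
import paddle

paddle.enable_static()
main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()

with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name="x", shape=[None, 16], dtype="float32")
    # Ops built inside fp16_guard are the ones rewritten to FP16 in pure-FP16 mode,
    # just as TransformerModel.forward is wrapped in the modeling.py hunk.
    with paddle.static.amp.fp16_guard():
        hidden = paddle.static.nn.fc(x, size=16)
    loss = paddle.mean(hidden)

    optimizer = paddle.optimizer.Adam(learning_rate=1e-3)  # assumed optimizer for the sketch
    amp_list = paddle.static.amp.AutoMixedPrecisionLists(
        custom_white_list=['softmax', 'layer_norm'],
        custom_black_list=['lookup_table_v2'])
    optimizer = paddle.static.amp.decorate(
        optimizer,
        amp_list,
        init_loss_scaling=128.0,           # matches the new scale_loss default
        use_dynamic_loss_scaling=True,
        use_pure_fp16=True)                # the new config switch
    optimizer.minimize(loss)

place = paddle.CUDAPlace(0)                # pure FP16 assumes a GPU place
exe = paddle.static.Executor(place)
exe.run(startup_prog)
# Cast the FP32-initialized parameters to FP16 before training starts,
# as the non-distributed branch added to train.py does via amp_init.
optimizer.amp_init(place)

feed = {"x": np.random.rand(4, 16).astype("float32")}
loss_val, = exe.run(main_prog, feed=feed, fetch_list=[loss])
```

For inference, the `cast_parameters_to_fp32` helper in the `predict.py` hunk reverses the `amp_init` cast after loading an FP16 checkpoint, so the exported program can be run with FP32 weights.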