From 938ae21fc5a2c9631fe192b4a6a5c762fa438a7c Mon Sep 17 00:00:00 2001
From: furnace <34057289+windstamp@users.noreply.github.com>
Date: Mon, 11 Jan 2021 16:43:28 +0800
Subject: [PATCH] add dygraph amp support for transformer (#5187)

* add dygraph amp support for transformer

* according params use_amp to decide use dygraph amp
---
 .../benchmark/transformer/dygraph/train.py | 28 +++++++++++++------
 1 file changed, 20 insertions(+), 8 deletions(-)

diff --git a/PaddleNLP/benchmark/transformer/dygraph/train.py b/PaddleNLP/benchmark/transformer/dygraph/train.py
index 51886b40..bb1a83a7 100644
--- a/PaddleNLP/benchmark/transformer/dygraph/train.py
+++ b/PaddleNLP/benchmark/transformer/dygraph/train.py
@@ -126,14 +126,26 @@ def do_train(args):
             train_reader_cost = time.time() - batch_start
             (src_word, trg_word, lbl_word) = input_data
 
-            logits = transformer(src_word=src_word, trg_word=trg_word)
-
-            sum_cost, avg_cost, token_num = criterion(logits, lbl_word)
-
-            avg_cost.backward()
-
-            optimizer.step()
-            optimizer.clear_grad()
+            if args.use_amp:
+                scaler = paddle.amp.GradScaler(
+                    init_loss_scaling=args.scale_loss)
+                with paddle.amp.auto_cast():
+                    logits = transformer(src_word=src_word, trg_word=trg_word)
+                    sum_cost, avg_cost, token_num = criterion(logits, lbl_word)
+
+                scaled = scaler.scale(avg_cost)  # scale the loss
+                scaled.backward()  # do backward
+
+                scaler.minimize(optimizer, scaled)  # update parameters
+                optimizer.clear_grad()
+            else:
+                logits = transformer(src_word=src_word, trg_word=trg_word)
+                sum_cost, avg_cost, token_num = criterion(logits, lbl_word)
+
+                avg_cost.backward()
+
+                optimizer.step()
+                optimizer.clear_grad()
 
             tokens_per_cards = token_num.numpy()
--
GitLab
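
For reference, the AMP branch added above follows PaddlePaddle's dygraph mixed-precision pattern: run the forward pass under paddle.amp.auto_cast(), scale the loss with paddle.amp.GradScaler, backpropagate the scaled loss, and let scaler.minimize() unscale the gradients and apply the parameter update. Below is a minimal, self-contained sketch of that pattern; the toy model, optimizer, and data (model, optimizer, x, y) are illustrative stand-ins rather than code from this patch, and the scaler is created once before the loop as in the Paddle AMP examples.

    import paddle

    # Illustrative stand-ins, not from the patch: a toy model and optimizer.
    model = paddle.nn.Linear(16, 16)
    optimizer = paddle.optimizer.Adam(parameters=model.parameters())

    # Loss scaler created once; init_loss_scaling plays the role of args.scale_loss.
    scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

    for step in range(10):
        x = paddle.randn([8, 16])
        y = paddle.randn([8, 16])

        with paddle.amp.auto_cast():        # forward pass in mixed precision
            loss = paddle.nn.functional.mse_loss(model(x), y)

        scaled = scaler.scale(loss)         # scale the loss to limit fp16 underflow
        scaled.backward()                   # backward on the scaled loss
        scaler.minimize(optimizer, scaled)  # unscale gradients and update parameters
        optimizer.clear_grad()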