diff --git a/PaddleNLP/benchmark/transformer/dygraph/train.py b/PaddleNLP/benchmark/transformer/dygraph/train.py
index 51886b408f453ddc7c6d56b7b53e8d9aa8f25c07..bb1a83a77638e41007c32e521d020935f22f20e6 100644
--- a/PaddleNLP/benchmark/transformer/dygraph/train.py
+++ b/PaddleNLP/benchmark/transformer/dygraph/train.py
@@ -126,14 +126,26 @@ def do_train(args):
             train_reader_cost = time.time() - batch_start
             (src_word, trg_word, lbl_word) = input_data
 
-            logits = transformer(src_word=src_word, trg_word=trg_word)
-
-            sum_cost, avg_cost, token_num = criterion(logits, lbl_word)
-
-            avg_cost.backward()
-
-            optimizer.step()
-            optimizer.clear_grad()
+            if args.use_amp:
+                scaler = paddle.amp.GradScaler(
+                    init_loss_scaling=args.scale_loss)
+                with paddle.amp.auto_cast():
+                    logits = transformer(src_word=src_word, trg_word=trg_word)
+                    sum_cost, avg_cost, token_num = criterion(logits, lbl_word)
+
+                scaled = scaler.scale(avg_cost)  # scale the loss
+                scaled.backward()  # do backward
+
+                scaler.minimize(optimizer, scaled)  # update parameters
+                optimizer.clear_grad()
+            else:
+                logits = transformer(src_word=src_word, trg_word=trg_word)
+                sum_cost, avg_cost, token_num = criterion(logits, lbl_word)
+
+                avg_cost.backward()
+
+                optimizer.step()
+                optimizer.clear_grad()
 
             tokens_per_cards = token_num.numpy()
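
For context, below is a minimal, self-contained sketch of the `paddle.amp` pattern this patch introduces. The `Linear` model, `MSELoss`, and random inputs are placeholders standing in for the benchmark's Transformer and label-smoothed criterion, not the repository's actual code. One point worth noting: `paddle.amp.GradScaler` keeps dynamic loss-scaling state across steps, so it is typically constructed once before the training loop, whereas the hunk above re-creates it on every iteration.

```python
# Hedged sketch of dygraph AMP training in PaddlePaddle; model/criterion are placeholders.
import paddle

model = paddle.nn.Linear(512, 512)                       # stands in for the Transformer
criterion = paddle.nn.MSELoss()                          # stands in for the benchmark criterion
optimizer = paddle.optimizer.Adam(parameters=model.parameters())

# Created once: the scaler tracks its dynamic loss scale across iterations
# (analogous to init_loss_scaling=args.scale_loss in the patch).
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

for step in range(10):
    x = paddle.randn([8, 512])
    y = paddle.randn([8, 512])

    with paddle.amp.auto_cast():                         # run the forward pass in fp16 where safe
        loss = criterion(model(x), y)

    scaled = scaler.scale(loss)                          # scale the loss to avoid fp16 gradient underflow
    scaled.backward()                                    # backward on the scaled loss
    scaler.minimize(optimizer, scaled)                   # unscale gradients and update parameters
    optimizer.clear_grad()
```

The `else` branch in the patch keeps the original fp32 path (`avg_cost.backward()`, `optimizer.step()`), so mixed precision only engages when `--use_amp` is passed.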