未验证 提交 938ae21f 编写于 作者: F furnace 提交者: GitHub

add dygraph amp support for transformer (#5187)

* add dygraph amp support for transformer

* use the use_amp parameter to decide whether to use dygraph amp
上级 969939e7
......@@ -126,14 +126,26 @@ def do_train(args):
train_reader_cost = time.time() - batch_start
(src_word, trg_word, lbl_word) = input_data
logits = transformer(src_word=src_word, trg_word=trg_word)
sum_cost, avg_cost, token_num = criterion(logits, lbl_word)
avg_cost.backward()
optimizer.step()
optimizer.clear_grad()
if args.use_amp:
scaler = paddle.amp.GradScaler(
init_loss_scaling=args.scale_loss)
with paddle.amp.auto_cast():
logits = transformer(src_word=src_word, trg_word=trg_word)
sum_cost, avg_cost, token_num = criterion(logits, lbl_word)
scaled = scaler.scale(avg_cost) # scale the loss
scaled.backward() # do backward
scaler.minimize(optimizer, scaled) # update parameters
optimizer.clear_grad()
else:
logits = transformer(src_word=src_word, trg_word=trg_word)
sum_cost, avg_cost, token_num = criterion(logits, lbl_word)
avg_cost.backward()
optimizer.step()
optimizer.clear_grad()
tokens_per_cards = token_num.numpy()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册