提交 028742b6 编写于 作者: 小湉湉's avatar 小湉湉

update lr scheduler

上级 94688264
......@@ -71,9 +71,10 @@ model:
###########################################################
# OPTIMIZER SETTING #
###########################################################
optimizer:
optim: adam # optimizer type
learning_rate: 0.001 # learning rate
scheduler_params:
d_model: 384
warmup_steps: 4000
grad_clip: 1.0
###########################################################
# TRAINING SETTING #
......@@ -84,7 +85,7 @@ num_snapshots: 5
###########################################################
# OTHER SETTING #
###########################################################
seed: 10086
seed: 0
token_list:
- <blank>
......
......@@ -23,8 +23,10 @@ import paddle
import yaml
from paddle import DataParallel
from paddle import distributed as dist
from paddle import nn
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from paddle.optimizer import Adam
from yacs.config import CfgNode
from paddlespeech.t2s.datasets.am_batch_fn import build_erniesat_collate_fn
......@@ -34,7 +36,6 @@ from paddlespeech.t2s.models.ernie_sat import ErnieSATEvaluator
from paddlespeech.t2s.models.ernie_sat import ErnieSATUpdater
from paddlespeech.t2s.training.extensions.snapshot import Snapshot
from paddlespeech.t2s.training.extensions.visualizer import VisualDL
from paddlespeech.t2s.training.optimizer import build_optimizers
from paddlespeech.t2s.training.seeding import seed_everything
from paddlespeech.t2s.training.trainer import Trainer
......@@ -118,12 +119,27 @@ def train_sp(args, config):
odim = config.n_mels
model = ErnieSAT(idim=vocab_size, odim=odim, **config["model"])
# model_path = "/home/yuantian01/PaddleSpeech_ERNIE_SAT/PaddleSpeech/examples/ernie_sat/pretrained_model/paddle_checkpoint_en/model.pdparams"
# state_dict = paddle.load(model_path)
# new_state_dict = {}
# for key, value in state_dict.items():
# new_key = "model." + key
# new_state_dict[new_key] = value
# model.set_state_dict(new_state_dict)
if world_size > 1:
model = DataParallel(model)
print("model done!")
optimizer = build_optimizers(model, **config["optimizer"])
scheduler = paddle.optimizer.lr.NoamDecay(
d_model=config["scheduler_params"]["d_model"],
warmup_steps=config["scheduler_params"]["warmup_steps"])
grad_clip = nn.ClipGradByGlobalNorm(config["grad_clip"])
optimizer = Adam(
learning_rate=scheduler,
grad_clip=grad_clip,
parameters=model.parameters())
print("optimizer done!")
output_dir = Path(args.output_dir)
......@@ -136,6 +152,7 @@ def train_sp(args, config):
updater = ErnieSATUpdater(
model=model,
optimizer=optimizer,
scheduler=scheduler,
dataloader=train_dataloader,
text_masking=config["model"]["text_masking"],
odim=odim,
......
......@@ -18,6 +18,7 @@ from paddle import distributed as dist
from paddle.io import DataLoader
from paddle.nn import Layer
from paddle.optimizer import Optimizer
from paddle.optimizer.lr import LRScheduler
from paddlespeech.t2s.modules.losses import MLMLoss
from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator
......@@ -34,12 +35,14 @@ class ErnieSATUpdater(StandardUpdater):
def __init__(self,
model: Layer,
optimizer: Optimizer,
scheduler: LRScheduler,
dataloader: DataLoader,
init_state=None,
text_masking: bool=False,
odim: int=80,
output_dir: Path=None):
super().__init__(model, optimizer, dataloader, init_state=None)
self.scheduler = scheduler
self.criterion = MLMLoss(text_masking=text_masking, odim=odim)
......@@ -75,10 +78,12 @@ class ErnieSATUpdater(StandardUpdater):
loss = mlm_loss + text_mlm_loss if text_mlm_loss is not None else mlm_loss
optimizer = self.optimizer
optimizer.clear_grad()
self.optimizer.clear_grad()
loss.backward()
optimizer.step()
self.optimizer.step()
self.scheduler.step()
scheduler_msg = 'lr: {}'.format(self.scheduler.last_lr)
report("train/loss", float(loss))
report("train/mlm_loss", float(mlm_loss))
......@@ -90,6 +95,7 @@ class ErnieSATUpdater(StandardUpdater):
losses_dict["loss"] = float(loss)
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
self.msg += ', ' + scheduler_msg
class ErnieSATEvaluator(StandardEvaluator):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册