提交 028742b6 编写于 作者: 小湉湉's avatar 小湉湉

update lr scheduler

上级 94688264
...@@ -71,9 +71,10 @@ model: ...@@ -71,9 +71,10 @@ model:
########################################################### ###########################################################
# OPTIMIZER SETTING # # OPTIMIZER SETTING #
########################################################### ###########################################################
optimizer: scheduler_params:
optim: adam # optimizer type d_model: 384
learning_rate: 0.001 # learning rate warmup_steps: 4000
grad_clip: 1.0
########################################################### ###########################################################
# TRAINING SETTING # # TRAINING SETTING #
...@@ -84,7 +85,7 @@ num_snapshots: 5 ...@@ -84,7 +85,7 @@ num_snapshots: 5
########################################################### ###########################################################
# OTHER SETTING # # OTHER SETTING #
########################################################### ###########################################################
seed: 10086 seed: 0
token_list: token_list:
- <blank> - <blank>
......
...@@ -23,8 +23,10 @@ import paddle ...@@ -23,8 +23,10 @@ import paddle
import yaml import yaml
from paddle import DataParallel from paddle import DataParallel
from paddle import distributed as dist from paddle import distributed as dist
from paddle import nn
from paddle.io import DataLoader from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler from paddle.io import DistributedBatchSampler
from paddle.optimizer import Adam
from yacs.config import CfgNode from yacs.config import CfgNode
from paddlespeech.t2s.datasets.am_batch_fn import build_erniesat_collate_fn from paddlespeech.t2s.datasets.am_batch_fn import build_erniesat_collate_fn
...@@ -34,7 +36,6 @@ from paddlespeech.t2s.models.ernie_sat import ErnieSATEvaluator ...@@ -34,7 +36,6 @@ from paddlespeech.t2s.models.ernie_sat import ErnieSATEvaluator
from paddlespeech.t2s.models.ernie_sat import ErnieSATUpdater from paddlespeech.t2s.models.ernie_sat import ErnieSATUpdater
from paddlespeech.t2s.training.extensions.snapshot import Snapshot from paddlespeech.t2s.training.extensions.snapshot import Snapshot
from paddlespeech.t2s.training.extensions.visualizer import VisualDL from paddlespeech.t2s.training.extensions.visualizer import VisualDL
from paddlespeech.t2s.training.optimizer import build_optimizers
from paddlespeech.t2s.training.seeding import seed_everything from paddlespeech.t2s.training.seeding import seed_everything
from paddlespeech.t2s.training.trainer import Trainer from paddlespeech.t2s.training.trainer import Trainer
...@@ -118,12 +119,27 @@ def train_sp(args, config): ...@@ -118,12 +119,27 @@ def train_sp(args, config):
odim = config.n_mels odim = config.n_mels
model = ErnieSAT(idim=vocab_size, odim=odim, **config["model"]) model = ErnieSAT(idim=vocab_size, odim=odim, **config["model"])
# model_path = "/home/yuantian01/PaddleSpeech_ERNIE_SAT/PaddleSpeech/examples/ernie_sat/pretrained_model/paddle_checkpoint_en/model.pdparams"
# state_dict = paddle.load(model_path)
# new_state_dict = {}
# for key, value in state_dict.items():
# new_key = "model." + key
# new_state_dict[new_key] = value
# model.set_state_dict(new_state_dict)
if world_size > 1: if world_size > 1:
model = DataParallel(model) model = DataParallel(model)
print("model done!") print("model done!")
optimizer = build_optimizers(model, **config["optimizer"]) scheduler = paddle.optimizer.lr.NoamDecay(
d_model=config["scheduler_params"]["d_model"],
warmup_steps=config["scheduler_params"]["warmup_steps"])
grad_clip = nn.ClipGradByGlobalNorm(config["grad_clip"])
optimizer = Adam(
learning_rate=scheduler,
grad_clip=grad_clip,
parameters=model.parameters())
print("optimizer done!") print("optimizer done!")
output_dir = Path(args.output_dir) output_dir = Path(args.output_dir)
...@@ -136,6 +152,7 @@ def train_sp(args, config): ...@@ -136,6 +152,7 @@ def train_sp(args, config):
updater = ErnieSATUpdater( updater = ErnieSATUpdater(
model=model, model=model,
optimizer=optimizer, optimizer=optimizer,
scheduler=scheduler,
dataloader=train_dataloader, dataloader=train_dataloader,
text_masking=config["model"]["text_masking"], text_masking=config["model"]["text_masking"],
odim=odim, odim=odim,
......
...@@ -18,6 +18,7 @@ from paddle import distributed as dist ...@@ -18,6 +18,7 @@ from paddle import distributed as dist
from paddle.io import DataLoader from paddle.io import DataLoader
from paddle.nn import Layer from paddle.nn import Layer
from paddle.optimizer import Optimizer from paddle.optimizer import Optimizer
from paddle.optimizer.lr import LRScheduler
from paddlespeech.t2s.modules.losses import MLMLoss from paddlespeech.t2s.modules.losses import MLMLoss
from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator from paddlespeech.t2s.training.extensions.evaluator import StandardEvaluator
...@@ -34,12 +35,14 @@ class ErnieSATUpdater(StandardUpdater): ...@@ -34,12 +35,14 @@ class ErnieSATUpdater(StandardUpdater):
def __init__(self, def __init__(self,
model: Layer, model: Layer,
optimizer: Optimizer, optimizer: Optimizer,
scheduler: LRScheduler,
dataloader: DataLoader, dataloader: DataLoader,
init_state=None, init_state=None,
text_masking: bool=False, text_masking: bool=False,
odim: int=80, odim: int=80,
output_dir: Path=None): output_dir: Path=None):
super().__init__(model, optimizer, dataloader, init_state=None) super().__init__(model, optimizer, dataloader, init_state=None)
self.scheduler = scheduler
self.criterion = MLMLoss(text_masking=text_masking, odim=odim) self.criterion = MLMLoss(text_masking=text_masking, odim=odim)
...@@ -75,10 +78,12 @@ class ErnieSATUpdater(StandardUpdater): ...@@ -75,10 +78,12 @@ class ErnieSATUpdater(StandardUpdater):
loss = mlm_loss + text_mlm_loss if text_mlm_loss is not None else mlm_loss loss = mlm_loss + text_mlm_loss if text_mlm_loss is not None else mlm_loss
optimizer = self.optimizer self.optimizer.clear_grad()
optimizer.clear_grad()
loss.backward() loss.backward()
optimizer.step() self.optimizer.step()
self.scheduler.step()
scheduler_msg = 'lr: {}'.format(self.scheduler.last_lr)
report("train/loss", float(loss)) report("train/loss", float(loss))
report("train/mlm_loss", float(mlm_loss)) report("train/mlm_loss", float(mlm_loss))
...@@ -90,6 +95,7 @@ class ErnieSATUpdater(StandardUpdater): ...@@ -90,6 +95,7 @@ class ErnieSATUpdater(StandardUpdater):
losses_dict["loss"] = float(loss) losses_dict["loss"] = float(loss)
self.msg += ', '.join('{}: {:>.6f}'.format(k, v) self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items()) for k, v in losses_dict.items())
self.msg += ', ' + scheduler_msg
class ErnieSATEvaluator(StandardEvaluator): class ErnieSATEvaluator(StandardEvaluator):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册