diff --git a/examples/csmsc/vits/conf/default.yaml b/examples/csmsc/vits/conf/default.yaml index a2aef998d2843bc4330339120304bf8741601bf3..7e9e9c1d9831b2743ac97dea1d1a4af9043bdba9 100644 --- a/examples/csmsc/vits/conf/default.yaml +++ b/examples/csmsc/vits/conf/default.yaml @@ -179,7 +179,7 @@ generator_first: False # whether to start updating generator first # OTHER TRAINING SETTING # ########################################################## num_snapshots: 10 # max number of snapshots to keep while training -train_max_steps: 350000 # Number of training steps. == total_iters / ngpus, total_iters = 1000000 -save_interval_steps: 1000 # Interval steps to save checkpoint. -eval_interval_steps: 250 # Interval steps to evaluate the network. +max_epoch: 1000 # Number of training epochs. +save_interval_epochs: 1 # Interval epochs to save checkpoint. +eval_interval_epochs: 1 # Interval steps to evaluate the network. seed: 777 # random seed number diff --git a/paddlespeech/t2s/exps/vits/train.py b/paddlespeech/t2s/exps/vits/train.py index 0e74bf631d14ebff622b643a4e9f5e5804d63339..cdfd300344871e8f9c149bca7ad149bfaa960254 100644 --- a/paddlespeech/t2s/exps/vits/train.py +++ b/paddlespeech/t2s/exps/vits/train.py @@ -230,17 +230,15 @@ def train_sp(args, config): output_dir=output_dir) trainer = Trainer( - updater, - stop_trigger=(config.train_max_steps, "iteration"), - out=output_dir) + updater, stop_trigger=(config.max_epoch, 'epoch'), out=output_dir) if dist.get_rank() == 0: trainer.extend( - evaluator, trigger=(config.eval_interval_steps, 'iteration')) + evaluator, trigger=(config.eval_interval_epochs, 'epoch')) trainer.extend(VisualDL(output_dir), trigger=(1, 'iteration')) trainer.extend( Snapshot(max_size=config.num_snapshots), - trigger=(config.save_interval_steps, 'iteration')) + trigger=(config.save_interval_epochs, 'epoch')) print("Trainer Done!") trainer.run() diff --git a/paddlespeech/t2s/models/vits/vits_updater.py b/paddlespeech/t2s/models/vits/vits_updater.py index 9f8be68034e946514a972c2720d6e0fe871f8877..e61e617cc8f837f39ae0269caf6edf47932e023a 100644 --- a/paddlespeech/t2s/models/vits/vits_updater.py +++ b/paddlespeech/t2s/models/vits/vits_updater.py @@ -166,7 +166,9 @@ class VITSUpdater(StandardUpdater): gen_loss.backward() self.optimizer_g.step() - self.scheduler_g.step() + # learning rate updates on each epoch. + if self.state.iteration % self.updates_per_epoch == 0: + self.scheduler_g.step() # reset cache if self.model.reuse_cache_gen or not self.model.training: @@ -202,7 +204,9 @@ class VITSUpdater(StandardUpdater): dis_loss.backward() self.optimizer_d.step() - self.scheduler_d.step() + # learning rate updates on each epoch. + if self.state.iteration % self.updates_per_epoch == 0: + self.scheduler_d.step() # reset cache if self.model.reuse_cache_dis or not self.model.training: