提交 4ac7db18 编写于 作者: 小湉湉's avatar 小湉湉

init for all works in train.py when ngpu>1, test=tts

上级 d6edb62d
......@@ -62,9 +62,3 @@ Contents
:caption: Acknowledgement
asr/reference
......@@ -160,9 +160,8 @@ def train_sp(args, config):
if dist.get_rank() == 0:
trainer.extend(evaluator, trigger=(1, "epoch"))
trainer.extend(VisualDL(output_dir), trigger=(1, "iteration"))
trainer.extend(
Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch'))
# print(trainer.extensions)
trainer.extend(
Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch'))
trainer.run()
......
......@@ -231,9 +231,9 @@ def train_sp(args, config):
trainer.extend(
evaluator, trigger=(config.eval_interval_steps, 'iteration'))
trainer.extend(VisualDL(output_dir), trigger=(1, 'iteration'))
trainer.extend(
Snapshot(max_size=config.num_snapshots),
trigger=(config.save_interval_steps, 'iteration'))
trainer.extend(
Snapshot(max_size=config.num_snapshots),
trigger=(config.save_interval_steps, 'iteration'))
print("Trainer Done!")
trainer.run()
......
......@@ -219,9 +219,9 @@ def train_sp(args, config):
trainer.extend(
evaluator, trigger=(config.eval_interval_steps, 'iteration'))
trainer.extend(VisualDL(output_dir), trigger=(1, 'iteration'))
trainer.extend(
Snapshot(max_size=config.num_snapshots),
trigger=(config.save_interval_steps, 'iteration'))
trainer.extend(
Snapshot(max_size=config.num_snapshots),
trigger=(config.save_interval_steps, 'iteration'))
print("Trainer Done!")
trainer.run()
......
......@@ -194,11 +194,10 @@ def train_sp(args, config):
trainer.extend(
evaluator, trigger=(config.eval_interval_steps, 'iteration'))
trainer.extend(VisualDL(output_dir), trigger=(1, 'iteration'))
trainer.extend(
Snapshot(max_size=config.num_snapshots),
trigger=(config.save_interval_steps, 'iteration'))
trainer.extend(
Snapshot(max_size=config.num_snapshots),
trigger=(config.save_interval_steps, 'iteration'))
# print(trainer.extensions.keys())
print("Trainer Done!")
trainer.run()
......
......@@ -212,9 +212,9 @@ def train_sp(args, config):
trainer.extend(
evaluator, trigger=(config.eval_interval_steps, 'iteration'))
trainer.extend(VisualDL(output_dir), trigger=(1, 'iteration'))
trainer.extend(
Snapshot(max_size=config.num_snapshots),
trigger=(config.save_interval_steps, 'iteration'))
trainer.extend(
Snapshot(max_size=config.num_snapshots),
trigger=(config.save_interval_steps, 'iteration'))
print("Trainer Done!")
trainer.run()
......
......@@ -171,8 +171,8 @@ def train_sp(args, config):
if dist.get_rank() == 0:
trainer.extend(evaluator, trigger=(1, "epoch"))
trainer.extend(VisualDL(output_dir), trigger=(1, "iteration"))
trainer.extend(
Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch'))
trainer.extend(
Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch'))
trainer.run()
......
......@@ -155,9 +155,8 @@ def train_sp(args, config):
if dist.get_rank() == 0:
trainer.extend(evaluator, trigger=(1, "epoch"))
trainer.extend(VisualDL(output_dir), trigger=(1, "iteration"))
trainer.extend(
Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch'))
# print(trainer.extensions)
trainer.extend(
Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch'))
trainer.run()
......
......@@ -148,9 +148,8 @@ def train_sp(args, config):
if dist.get_rank() == 0:
trainer.extend(evaluator, trigger=(1, "epoch"))
trainer.extend(VisualDL(output_dir), trigger=(1, "iteration"))
trainer.extend(
Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch'))
# print(trainer.extensions)
trainer.extend(
Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch'))
trainer.run()
......
......@@ -168,9 +168,9 @@ def train_sp(args, config):
trainer.extend(
evaluator, trigger=(config.eval_interval_steps, 'iteration'))
trainer.extend(VisualDL(output_dir), trigger=(1, 'iteration'))
trainer.extend(
Snapshot(max_size=config.num_snapshots),
trigger=(config.save_interval_steps, 'iteration'))
trainer.extend(
Snapshot(max_size=config.num_snapshots),
trigger=(config.save_interval_steps, 'iteration'))
print("Trainer Done!")
trainer.run()
......
......@@ -135,9 +135,8 @@ def train_sp(args, config):
if dist.get_rank() == 0:
trainer.extend(evaluator, trigger=(1, "epoch"))
trainer.extend(VisualDL(output_dir), trigger=(1, "iteration"))
trainer.extend(
Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch'))
# print(trainer.extensions)
trainer.extend(
Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch'))
trainer.run()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册